diff options
| author | Chris Lattner <clattner@google.com> | 2018-12-27 21:21:41 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 14:42:23 -0700 |
| commit | 5187cfcf03d36fcd9a08adb768d0bc584ef9e50d (patch) | |
| tree | a78a2e7454c02452df8370b107a1c1ed336bad64 /mlir/lib | |
| parent | 3b021d7f2e6bfd42593af76c02d2aa9c26beaaf0 (diff) | |
| download | bcm5719-llvm-5187cfcf03d36fcd9a08adb768d0bc584ef9e50d.tar.gz bcm5719-llvm-5187cfcf03d36fcd9a08adb768d0bc584ef9e50d.zip | |
Merge Operation into OperationInst and standardize nomenclature around
OperationInst. This is a big mechanical patch.
This is step 16/n towards merging instructions and statements, NFC.
PiperOrigin-RevId: 227093712
Diffstat (limited to 'mlir/lib')
42 files changed, 474 insertions, 570 deletions
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 78115b974a1..f3fde8bb95f 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -478,14 +478,14 @@ bool mlir::getFlattenedAffineExprs( localVarCst); } -/// Returns the sequence of AffineApplyOp OperationStmts operation in +/// Returns the sequence of AffineApplyOp OperationInsts operation in /// 'affineApplyOps', which are reachable via a search starting from 'operands', /// and ending at operands which are not defined by AffineApplyOps. // TODO(andydavis) Add a method to AffineApplyOp which forward substitutes // the AffineApplyOp into any user AffineApplyOps. void mlir::getReachableAffineApplyOps( ArrayRef<Value *> operands, - SmallVectorImpl<OperationStmt *> &affineApplyOps) { + SmallVectorImpl<OperationInst *> &affineApplyOps) { struct State { // The ssa value for this node in the DFS traversal. Value *value; @@ -499,9 +499,9 @@ void mlir::getReachableAffineApplyOps( while (!worklist.empty()) { State &state = worklist.back(); - auto *opStmt = state.value->getDefiningStmt(); - // Note: getDefiningStmt will return nullptr if the operand is not an - // OperationStmt (i.e. ForStmt), which is a terminator for the search. + auto *opStmt = state.value->getDefiningInst(); + // Note: getDefiningInst will return nullptr if the operand is not an + // OperationInst (i.e. ForStmt), which is a terminator for the search. if (opStmt == nullptr || !opStmt->isa<AffineApplyOp>()) { worklist.pop_back(); continue; @@ -531,7 +531,7 @@ void mlir::getReachableAffineApplyOps( // operands of 'valueMap'. void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) { // Gather AffineApplyOps reachable from 'indices'. - SmallVector<OperationStmt *, 4> affineApplyOps; + SmallVector<OperationInst *, 4> affineApplyOps; getReachableAffineApplyOps(valueMap->getOperands(), affineApplyOps); // Compose AffineApplyOps in 'affineApplyOps'. for (auto *opStmt : affineApplyOps) { @@ -842,7 +842,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, auto *symbol = operands[i]; assert(symbol->isValidSymbol()); // Check if the symbol is a constant. - if (auto *opStmt = symbol->getDefiningStmt()) { + if (auto *opStmt = symbol->getDefiningInst()) { if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) { dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol), constOp->getValue()); diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index bfdaceff7e7..dd564df3017 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1269,7 +1269,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) { addSymbolId(getNumSymbolIds(), const_cast<Value *>(operand)); loc = getNumDimIds() + getNumSymbolIds() - 1; // Check if the symbol is a constant. - if (auto *opStmt = operand->getDefiningStmt()) { + if (auto *opStmt = operand->getDefiningInst()) { if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) { setIdToConstant(*operand, constOp->getValue()); } diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 7213ba5986a..85af39222c4 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -127,7 +127,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) { bool mlir::isAccessInvariant(const Value &iv, const Value &index) { assert(isa<ForStmt>(iv) && "iv must be a ForStmt"); assert(index.getType().isa<IndexType>() && "index must be of IndexType"); - SmallVector<OperationStmt *, 4> affineApplyOps; + SmallVector<OperationInst *, 4> affineApplyOps; getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps); if (affineApplyOps.empty()) { @@ -234,13 +234,13 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) { } static bool isVectorTransferReadOrWrite(const Statement &stmt) { - const auto *opStmt = cast<OperationStmt>(&stmt); + const auto *opStmt = cast<OperationInst>(&stmt); return opStmt->isa<VectorTransferReadOp>() || opStmt->isa<VectorTransferWriteOp>(); } using VectorizableStmtFun = - std::function<bool(const ForStmt &, const OperationStmt &)>; + std::function<bool(const ForStmt &, const OperationInst &)>; static bool isVectorizableLoopWithCond(const ForStmt &loop, VectorizableStmtFun isVectorizableStmt) { @@ -265,7 +265,7 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop, auto loadAndStores = matcher::Op(matcher::isLoadOrStore); auto loadAndStoresMatched = loadAndStores.match(forStmt); for (auto ls : loadAndStoresMatched) { - auto *op = cast<OperationStmt>(ls.first); + auto *op = cast<OperationInst>(ls.first); auto load = op->dyn_cast<LoadOp>(); auto store = op->dyn_cast<StoreOp>(); // Only scalar types are considered vectorizable, all load/store must be @@ -285,7 +285,7 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop, bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( const ForStmt &loop, unsigned fastestVaryingDim) { VectorizableStmtFun fun( - [fastestVaryingDim](const ForStmt &loop, const OperationStmt &op) { + [fastestVaryingDim](const ForStmt &loop, const OperationInst &op) { auto load = op.dyn_cast<LoadOp>(); auto store = op.dyn_cast<StoreOp>(); return load ? isContiguousAccess(loop, *load, fastestVaryingDim) @@ -297,7 +297,7 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( bool mlir::isVectorizableLoop(const ForStmt &loop) { VectorizableStmtFun fun( // TODO: implement me - [](const ForStmt &loop, const OperationStmt &op) { return true; }); + [](const ForStmt &loop, const OperationInst &op) { return true; }); return isVectorizableLoopWithCond(loop, fun); } @@ -314,7 +314,7 @@ bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt, for (const auto &stmt : *forBody) { // A for or if stmt does not produce any def/results (that are used // outside). - if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) { + if (const auto *opStmt = dyn_cast<OperationInst>(&stmt)) { for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) { const Value *result = opStmt->getResult(i); for (const StmtOperand &use : result->getUses()) { diff --git a/mlir/lib/Analysis/MLFunctionMatcher.cpp b/mlir/lib/Analysis/MLFunctionMatcher.cpp index c227aa3fcdd..c03fed5986b 100644 --- a/mlir/lib/Analysis/MLFunctionMatcher.cpp +++ b/mlir/lib/Analysis/MLFunctionMatcher.cpp @@ -200,7 +200,7 @@ namespace mlir { namespace matcher { MLFunctionMatcher Op(FilterFunctionType filter) { - return MLFunctionMatcher(Statement::Kind::Operation, {}, filter); + return MLFunctionMatcher(Statement::Kind::OperationInst, {}, filter); } MLFunctionMatcher If(MLFunctionMatcher child) { @@ -246,7 +246,7 @@ bool isReductionLoop(const Statement &stmt) { }; bool isLoadOrStore(const Statement &stmt) { - const auto *opStmt = dyn_cast<OperationStmt>(&stmt); + const auto *opStmt = dyn_cast<OperationInst>(&stmt); return opStmt && (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>()); }; diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 995bb466fef..1cb039fe00e 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -45,7 +45,7 @@ struct MemRefBoundCheck : public FunctionPass, StmtWalker<MemRefBoundCheck> { // Not applicable to CFG functions. PassResult runOnCFGFunction(CFGFunction *f) override { return success(); } - void visitOperationStmt(OperationStmt *opStmt); + void visitOperationInst(OperationInst *opStmt); static char passID; }; @@ -58,7 +58,7 @@ FunctionPass *mlir::createMemRefBoundCheckPass() { return new MemRefBoundCheck(); } -void MemRefBoundCheck::visitOperationStmt(OperationStmt *opStmt) { +void MemRefBoundCheck::visitOperationInst(OperationInst *opStmt) { if (auto loadOp = opStmt->dyn_cast<LoadOp>()) { boundCheckLoadOrStoreOp(loadOp); } else if (auto storeOp = opStmt->dyn_cast<StoreOp>()) { diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 7c57a66310a..ec33c619a17 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -40,7 +40,7 @@ namespace { /// Checks dependences between all pairs of memref accesses in an MLFunction. struct MemRefDependenceCheck : public FunctionPass, StmtWalker<MemRefDependenceCheck> { - SmallVector<OperationStmt *, 4> loadsAndStores; + SmallVector<OperationInst *, 4> loadsAndStores; explicit MemRefDependenceCheck() : FunctionPass(&MemRefDependenceCheck::passID) {} @@ -48,7 +48,7 @@ struct MemRefDependenceCheck : public FunctionPass, // Not applicable to CFG functions. PassResult runOnCFGFunction(CFGFunction *f) override { return success(); } - void visitOperationStmt(OperationStmt *opStmt) { + void visitOperationInst(OperationInst *opStmt) { if (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>()) { loadsAndStores.push_back(opStmt); } @@ -66,7 +66,7 @@ FunctionPass *mlir::createMemRefDependenceCheckPass() { // Adds memref access indices 'opIndices' from 'memrefType' to 'access'. static void addMemRefAccessIndices( - llvm::iterator_range<Operation::const_operand_iterator> opIndices, + llvm::iterator_range<OperationInst::const_operand_iterator> opIndices, MemRefType memrefType, MemRefAccess *access) { access->indices.reserve(memrefType.getRank()); for (auto *index : opIndices) { @@ -75,7 +75,7 @@ static void addMemRefAccessIndices( } // Populates 'access' with memref, indices and opstmt from 'loadOrStoreOpStmt'. -static void getMemRefAccess(const OperationStmt *loadOrStoreOpStmt, +static void getMemRefAccess(const OperationInst *loadOrStoreOpStmt, MemRefAccess *access) { access->opStmt = loadOrStoreOpStmt; if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) { @@ -131,7 +131,7 @@ getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth, // "source" access and all subsequent "destination" accesses in // 'loadsAndStores'. Emits the result of the dependence check as a note with // the source access. -static void checkDependences(ArrayRef<OperationStmt *> loadsAndStores) { +static void checkDependences(ArrayRef<OperationInst *> loadsAndStores) { for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) { auto *srcOpStmt = loadsAndStores[i]; MemRefAccess srcAccess; diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index d9a0edd6d83..cea0c087297 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -38,7 +38,7 @@ struct PrintOpStatsPass : public FunctionPass, StmtWalker<PrintOpStatsPass> { // Process ML functions and operation statments in ML functions. PassResult runOnMLFunction(MLFunction *function) override; - void visitOperationStmt(OperationStmt *stmt); + void visitOperationInst(OperationInst *stmt); // Print summary of op stats. void printSummary(); @@ -69,7 +69,7 @@ PassResult PrintOpStatsPass::runOnCFGFunction(CFGFunction *function) { return success(); } -void PrintOpStatsPass::visitOperationStmt(OperationStmt *stmt) { +void PrintOpStatsPass::visitOperationInst(OperationInst *stmt) { ++opCount[stmt->getName().getStringRef()]; } diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index b7873f8327f..c06bf4df61e 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -52,7 +52,7 @@ void mlir::getForwardSlice(Statement *stmt, return; } - if (auto *opStmt = dyn_cast<OperationStmt>(stmt)) { + if (auto *opStmt = dyn_cast<OperationInst>(stmt)) { assert(opStmt->getNumResults() <= 1 && "NYI: multiple results"); if (opStmt->getNumResults() > 0) { for (auto &u : opStmt->getResult(0)->getUses()) { @@ -102,7 +102,7 @@ void mlir::getBackwardSlice(Statement *stmt, } for (auto *operand : stmt->getOperands()) { - auto *stmt = operand->getDefiningStmt(); + auto *stmt = operand->getDefiningInst(); if (backwardSlice->count(stmt) == 0) { getBackwardSlice(stmt, backwardSlice, filter, /*topLevel=*/false); @@ -156,7 +156,7 @@ struct DFSState { } // namespace static void DFSPostorder(Statement *current, DFSState *state) { - auto *opStmt = cast<OperationStmt>(current); + auto *opStmt = cast<OperationInst>(current); assert(opStmt->getNumResults() <= 1 && "NYI: multi-result"); if (opStmt->getNumResults() > 0) { for (auto &u : opStmt->getResult(0)->getUses()) { diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index e6975ac5d09..a63723b333c 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -145,7 +145,7 @@ Optional<int64_t> MemRefRegion::getBoundingConstantSizeAndShape( // // TODO(bondhugula): extend this to any other memref dereferencing ops // (dma_start, dma_wait). -bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth, +bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth, MemRefRegion *region) { OpPointer<LoadOp> loadOp; OpPointer<StoreOp> storeOp; @@ -204,7 +204,7 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth, auto *symbol = accessValueMap.getOperand(i); assert(symbol->isValidSymbol()); // Check if the symbol is a constant. - if (auto *opStmt = symbol->getDefiningStmt()) { + if (auto *opStmt = symbol->getDefiningInst()) { if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) { regionCst->setIdToConstant(*symbol, constOp->getValue()); } @@ -282,7 +282,7 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, std::is_same<LoadOrStoreOpPointer, OpPointer<StoreOp>>::value, "function argument should be either a LoadOp or a StoreOp"); - OperationStmt *opStmt = cast<OperationStmt>(loadOrStoreOp->getOperation()); + OperationInst *opStmt = cast<OperationInst>(loadOrStoreOp->getOperation()); MemRefRegion region; if (!getMemRefRegion(opStmt, /*loopDepth=*/0, ®ion)) return false; diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index ec19194f2fa..cd9451cd5e9 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -104,7 +104,7 @@ Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(VectorType superVectorType, /// header file. static AffineMap makePermutationMap( MLIRContext *context, - llvm::iterator_range<Operation::operand_iterator> indices, + llvm::iterator_range<OperationInst::operand_iterator> indices, const DenseMap<ForStmt *, unsigned> &enclosingLoopToVectorDim) { using functional::makePtrDynCaster; using functional::map; @@ -157,7 +157,7 @@ static SetVector<ForStmt *> getEnclosingForStmts(Statement *stmt) { } AffineMap -mlir::makePermutationMap(OperationStmt *opStmt, +mlir::makePermutationMap(OperationInst *opStmt, const DenseMap<ForStmt *, unsigned> &loopToVectorDim) { DenseMap<ForStmt *, unsigned> enclosingLoopToVectorDim; auto enclosingLoops = getEnclosingForStmts(opStmt); @@ -178,7 +178,7 @@ mlir::makePermutationMap(OperationStmt *opStmt, enclosingLoopToVectorDim); } -bool mlir::matcher::operatesOnStrictSuperVectors(const OperationStmt &opStmt, +bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt, VectorType subVectorType) { // First, extract the vector type and ditinguish between: // a. ops that *must* lower a super-vector (i.e. vector_transfer_read, diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index e7abb899a11..e1de6191de6 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -51,7 +51,7 @@ namespace { /// class Verifier { public: - bool failure(const Twine &message, const Operation &value) { + bool failure(const Twine &message, const OperationInst &value) { return value.emitError(message); } @@ -62,15 +62,15 @@ public: bool failure(const Twine &message, const BasicBlock &bb) { // Take the location information for the first instruction in the block. if (!bb.empty()) - if (auto *op = dyn_cast<OperationStmt>(&bb.front())) + if (auto *op = dyn_cast<OperationInst>(&bb.front())) return failure(message, *op); // Worst case, fall back to using the function's location. return failure(message, fn); } - bool verifyOperation(const Operation &op); - bool verifyAttribute(Attribute attr, const Operation &op); + bool verifyOperation(const OperationInst &op); + bool verifyAttribute(Attribute attr, const OperationInst &op); protected: explicit Verifier(const Function &fn) : fn(fn) {} @@ -82,7 +82,7 @@ private: } // end anonymous namespace // Check that function attributes are all well formed. -bool Verifier::verifyAttribute(Attribute attr, const Operation &op) { +bool Verifier::verifyAttribute(Attribute attr, const OperationInst &op) { if (!attr.isOrContainsFunction()) return false; @@ -109,9 +109,9 @@ bool Verifier::verifyAttribute(Attribute attr, const Operation &op) { return false; } -/// Check the invariants of the specified operation instruction or statement. -bool Verifier::verifyOperation(const Operation &op) { - if (op.getOperationFunction() != &fn) +/// Check the invariants of the specified operation. +bool Verifier::verifyOperation(const OperationInst &op) { + if (op.getFunction() != &fn) return failure("operation in the wrong function", op); // Check that operands are non-nil and structurally ok. @@ -245,7 +245,7 @@ struct MLFuncVerifier : public Verifier, public StmtWalker<MLFuncVerifier> { MLFuncVerifier(const MLFunction &fn) : Verifier(fn), fn(fn) {} - void visitOperationStmt(OperationStmt *opStmt) { + void visitOperationInst(OperationInst *opStmt) { hadError |= verifyOperation(*opStmt); } @@ -302,14 +302,14 @@ bool MLFuncVerifier::verifyDominance() { if (!liveValues.count(opValue)) { stmt.emitError("operand #" + Twine(operandNo) + " does not dominate this use"); - if (auto *useStmt = opValue->getDefiningStmt()) + if (auto *useStmt = opValue->getDefiningInst()) useStmt->emitNote("operand defined here"); return true; } ++operandNo; } - if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) { + if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) { // Operations define values, add them to the hash table. for (auto *result : opStmt->getResults()) liveValues.insert(result, true); @@ -344,7 +344,7 @@ bool MLFuncVerifier::verifyReturn() { return failure(missingReturnMsg, fn); const auto &stmt = fn.getBody()->getStatements().back(); - if (const auto *op = dyn_cast<OperationStmt>(&stmt)) { + if (const auto *op = dyn_cast<OperationInst>(&stmt)) { if (!op->isReturn()) return failure(missingReturnMsg, fn); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index c44ce4d4d6c..9f465ab8507 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -120,10 +120,10 @@ private: void visitStatement(const Statement *stmt); void visitForStmt(const ForStmt *forStmt); void visitIfStmt(const IfStmt *ifStmt); - void visitOperationStmt(const OperationStmt *opStmt); + void visitOperationInst(const OperationInst *opStmt); void visitType(Type type); void visitAttribute(Attribute attr); - void visitOperation(const Operation *op); + void visitOperation(const OperationInst *op); DenseMap<AffineMap, int> affineMapIds; std::vector<AffineMap> affineMapsById; @@ -161,7 +161,7 @@ void ModuleState::visitAttribute(Attribute attr) { } } -void ModuleState::visitOperation(const Operation *op) { +void ModuleState::visitOperation(const OperationInst *op) { // Visit all the types used in the operation. for (auto *operand : op->getOperands()) visitType(operand->getType()); @@ -212,7 +212,7 @@ void ModuleState::visitForStmt(const ForStmt *forStmt) { visitStatement(&childStmt); } -void ModuleState::visitOperationStmt(const OperationStmt *opStmt) { +void ModuleState::visitOperationInst(const OperationInst *opStmt) { for (auto attr : opStmt->getAttrs()) visitAttribute(attr.second); } @@ -223,8 +223,8 @@ void ModuleState::visitStatement(const Statement *stmt) { return visitIfStmt(cast<IfStmt>(stmt)); case Statement::Kind::For: return visitForStmt(cast<ForStmt>(stmt)); - case Statement::Kind::Operation: - return visitOperationStmt(cast<OperationStmt>(stmt)); + case Statement::Kind::OperationInst: + return visitOperationInst(cast<OperationInst>(stmt)); default: return; } @@ -944,8 +944,8 @@ class FunctionPrinter : public ModulePrinter, private OpAsmPrinter { public: FunctionPrinter(const ModulePrinter &other) : ModulePrinter(other) {} - void printOperation(const Operation *op); - void printDefaultOp(const Operation *op); + void printOperation(const OperationInst *op); + void printDefaultOp(const OperationInst *op); // Implement OpAsmPrinter. raw_ostream &getStream() const { return os; } @@ -983,7 +983,7 @@ protected: llvm::raw_svector_ostream specialName(specialNameBuffer); // Give constant integers special names. - if (auto *op = value->getDefiningOperation()) { + if (auto *op = value->getDefiningInst()) { if (auto intOp = op->dyn_cast<ConstantIntOp>()) { // i1 constants get special names. if (intOp->getType().isInteger(1)) { @@ -1111,7 +1111,7 @@ private: }; } // end anonymous namespace -void FunctionPrinter::printOperation(const Operation *op) { +void FunctionPrinter::printOperation(const OperationInst *op) { if (op->getNumResults()) { printValueID(op->getResult(0), /*printResultNo=*/false); os << " = "; @@ -1128,7 +1128,7 @@ void FunctionPrinter::printOperation(const Operation *op) { printDefaultOp(op); } -void FunctionPrinter::printDefaultOp(const Operation *op) { +void FunctionPrinter::printDefaultOp(const OperationInst *op) { os << '"'; printEscapedString(op->getName().getStringRef(), os); os << "\"("; @@ -1172,7 +1172,7 @@ public: void print(const Instruction *inst); - void printSuccessorAndUseList(const Operation *term, unsigned index); + void printSuccessorAndUseList(const OperationInst *term, unsigned index); void printBBName(const BasicBlock *block) { os << "bb" << getBBID(block); } @@ -1302,7 +1302,7 @@ void CFGFunctionPrinter::printBranchOperands(const Range &range) { os << ')'; } -void CFGFunctionPrinter::printSuccessorAndUseList(const Operation *term, +void CFGFunctionPrinter::printSuccessorAndUseList(const OperationInst *term, unsigned index) { printBBName(term->getSuccessor(index)); printBranchOperands(term->getSuccessorOperands(index)); @@ -1331,11 +1331,11 @@ public: // Methods to print ML function statements. void print(const Statement *stmt); - void print(const OperationStmt *stmt); + void print(const OperationInst *stmt); void print(const ForStmt *stmt); void print(const IfStmt *stmt); void print(const StmtBlock *block); - void printSuccessorAndUseList(const Operation *term, unsigned index) { + void printSuccessorAndUseList(const OperationInst *term, unsigned index) { assert(false && "MLFunctions do not have terminators with successors."); } @@ -1371,7 +1371,7 @@ void MLFunctionPrinter::numberValues() { // the first result of the operation statements. struct NumberValuesPass : public StmtWalker<NumberValuesPass> { NumberValuesPass(MLFunctionPrinter *printer) : printer(printer) {} - void visitOperationStmt(OperationStmt *stmt) { + void visitOperationInst(OperationInst *stmt) { if (stmt->getNumResults() != 0) printer->numberValueID(stmt->getResult(0)); } @@ -1421,8 +1421,8 @@ void MLFunctionPrinter::print(const StmtBlock *block) { void MLFunctionPrinter::print(const Statement *stmt) { switch (stmt->getKind()) { - case Statement::Kind::Operation: - return print(cast<OperationStmt>(stmt)); + case Statement::Kind::OperationInst: + return print(cast<OperationInst>(stmt)); case Statement::Kind::For: return print(cast<ForStmt>(stmt)); case Statement::Kind::If: @@ -1430,7 +1430,7 @@ void MLFunctionPrinter::print(const Statement *stmt) { } } -void MLFunctionPrinter::print(const OperationStmt *stmt) { +void MLFunctionPrinter::print(const OperationInst *stmt) { os.indent(numSpaces); printOperation(stmt); } @@ -1580,7 +1580,7 @@ void Value::print(raw_ostream &os) const { os << "<block argument>\n"; return; case Value::Kind::StmtResult: - return getDefiningStmt()->print(os); + return getDefiningInst()->print(os); case Value::Kind::ForStmt: return cast<ForStmt>(this)->print(os); } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 0d3e54364b3..81a3b7c2950 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -290,7 +290,7 @@ StmtBlock *FuncBuilder::createBlock(StmtBlock *insertBefore) { } /// Create an operation given the fields represented as an OperationState. -OperationStmt *FuncBuilder::createOperation(const OperationState &state) { +OperationInst *FuncBuilder::createOperation(const OperationState &state) { auto *op = OperationInst::create(state.location, state.name, state.operands, state.types, state.attributes, state.successors, context); diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp index 50ab254dd76..a87ae6b85f0 100644 --- a/mlir/lib/IR/BuiltinOps.cpp +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -36,8 +36,8 @@ BuiltinDialect::BuiltinDialect(MLIRContext *context) addOperations<AffineApplyOp, BranchOp, CondBranchOp, ConstantOp, ReturnOp>(); } -void mlir::printDimAndSymbolList(Operation::const_operand_iterator begin, - Operation::const_operand_iterator end, +void mlir::printDimAndSymbolList(OperationInst::const_operand_iterator begin, + OperationInst::const_operand_iterator end, unsigned numDims, OpAsmPrinter *p) { *p << '('; p->printOperands(begin, begin + numDims); @@ -188,14 +188,12 @@ void BranchOp::print(OpAsmPrinter *p) const { bool BranchOp::verify() const { // ML functions do not have branching terminators. - if (getOperation()->getOperationFunction()->isML()) + if (getOperation()->getFunction()->isML()) return (emitOpError("cannot occur in a ML function"), true); return false; } -BasicBlock *BranchOp::getDest() const { - return getOperation()->getSuccessor(0); -} +BasicBlock *BranchOp::getDest() { return getOperation()->getSuccessor(0); } void BranchOp::setDest(BasicBlock *block) { return getOperation()->setSuccessor(block, 0); @@ -258,18 +256,18 @@ void CondBranchOp::print(OpAsmPrinter *p) const { bool CondBranchOp::verify() const { // ML functions do not have branching terminators. - if (getOperation()->getOperationFunction()->isML()) + if (getOperation()->getFunction()->isML()) return (emitOpError("cannot occur in a ML function"), true); if (!getCondition()->getType().isInteger(1)) return emitOpError("expected condition type was boolean (i1)"); return false; } -BasicBlock *CondBranchOp::getTrueDest() const { +BasicBlock *CondBranchOp::getTrueDest() { return getOperation()->getSuccessor(trueIndex); } -BasicBlock *CondBranchOp::getFalseDest() const { +BasicBlock *CondBranchOp::getFalseDest() { return getOperation()->getSuccessor(falseIndex); } @@ -399,13 +397,13 @@ void ConstantFloatOp::build(Builder *builder, OperationState *result, ConstantOp::build(builder, result, builder->getFloatAttr(type, value), type); } -bool ConstantFloatOp::isClassFor(const Operation *op) { +bool ConstantFloatOp::isClassFor(const OperationInst *op) { return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isa<FloatType>(); } /// ConstantIntOp only matches values whose result type is an IntegerType. -bool ConstantIntOp::isClassFor(const Operation *op) { +bool ConstantIntOp::isClassFor(const OperationInst *op) { return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isa<IntegerType>(); } @@ -427,7 +425,7 @@ void ConstantIntOp::build(Builder *builder, OperationState *result, } /// ConstantIndexOp only matches values whose result type is Index. -bool ConstantIndexOp::isClassFor(const Operation *op) { +bool ConstantIndexOp::isClassFor(const OperationInst *op) { return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex(); } @@ -470,7 +468,7 @@ void ReturnOp::print(OpAsmPrinter *p) const { } bool ReturnOp::verify() const { - auto *function = cast<OperationStmt>(getOperation())->getFunction(); + auto *function = cast<OperationInst>(getOperation())->getFunction(); // The operand number and types must match the function signature. const auto &results = function->getType().getResults(); diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 62f1dca067d..19b137071f4 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -161,34 +161,34 @@ bool Function::emitError(const Twine &message) const { // MLFunction implementation. //===----------------------------------------------------------------------===// -const OperationStmt *MLFunction::getReturnStmt() const { - return cast<OperationStmt>(&getBody()->back()); +const OperationInst *MLFunction::getReturnStmt() const { + return cast<OperationInst>(&getBody()->back()); } -OperationStmt *MLFunction::getReturnStmt() { - return cast<OperationStmt>(&getBody()->back()); +OperationInst *MLFunction::getReturnStmt() { + return cast<OperationInst>(&getBody()->back()); } -void MLFunction::walk(std::function<void(OperationStmt *)> callback) { +void MLFunction::walk(std::function<void(OperationInst *)> callback) { struct Walker : public StmtWalker<Walker> { - std::function<void(OperationStmt *)> const &callback; - Walker(std::function<void(OperationStmt *)> const &callback) + std::function<void(OperationInst *)> const &callback; + Walker(std::function<void(OperationInst *)> const &callback) : callback(callback) {} - void visitOperationStmt(OperationStmt *opStmt) { callback(opStmt); } + void visitOperationInst(OperationInst *opStmt) { callback(opStmt); } }; Walker v(callback); v.walk(this); } -void MLFunction::walkPostOrder(std::function<void(OperationStmt *)> callback) { +void MLFunction::walkPostOrder(std::function<void(OperationInst *)> callback) { struct Walker : public StmtWalker<Walker> { - std::function<void(OperationStmt *)> const &callback; - Walker(std::function<void(OperationStmt *)> const &callback) + std::function<void(OperationInst *)> const &callback; + Walker(std::function<void(OperationInst *)> const &callback) : callback(callback) {} - void visitOperationStmt(OperationStmt *opStmt) { callback(opStmt); } + void visitOperationInst(OperationInst *opStmt) { callback(opStmt); } }; Walker v(callback); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index da0bc4b1595..abc3e1cfda4 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -450,7 +450,7 @@ auto MLIRContext::getDiagnosticHandler() const -> DiagnosticHandlerTy { /// This emits a diagnostic using the registered issue handle if present, or /// with the default behavior if not. The MLIR compiler should not generally -/// interact with this, it should use methods on Operation instead. +/// interact with this, it should use methods on OperationInst instead. void MLIRContext::emitDiagnostic(Location location, const llvm::Twine &message, DiagnosticKind kind) const { // Check to see if we are emitting a diagnostic on a fused location. diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 0526c6ea610..6a9b37560db 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1,4 +1,4 @@ -//===- Operation.cpp - MLIR Operation Class -------------------------------===// +//===- Operation.cpp - Operation support code -----------------------------===// // // Copyright 2019 The MLIR Authors. // @@ -15,8 +15,6 @@ // limitations under the License. // ============================================================================= -#include "mlir/IR/Operation.h" -#include "AttributeListStorage.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/MLIRContext.h" @@ -53,200 +51,6 @@ OperationName OperationName::getFromOpaquePointer(void *pointer) { OpAsmParser::~OpAsmParser() {} //===----------------------------------------------------------------------===// -// Operation class -//===----------------------------------------------------------------------===// - -Operation::Operation(OperationName name, ArrayRef<NamedAttribute> attrs, - Location location, MLIRContext *context) - : Statement(Kind::Operation, location), name(name) { - this->attrs = AttributeListStorage::get(attrs, context); - -#ifndef NDEBUG - for (auto elt : attrs) - assert(elt.second != nullptr && "Attributes cannot have null entries"); -#endif -} - -Operation::~Operation() {} - - -/// Return the function this operation is defined in. -Function *Operation::getOperationFunction() { - return llvm::cast<OperationStmt>(this)->getFunction(); -} - -/// Return the number of results this operation has. -unsigned Operation::getNumResults() const { - return llvm::cast<OperationStmt>(this)->getNumResults(); -} - -/// Return the indicated result. -Value *Operation::getResult(unsigned idx) { - return llvm::cast<OperationStmt>(this)->getResult(idx); -} - -unsigned Operation::getNumSuccessors() const { - assert(isTerminator() && "Only terminators have successors."); - return llvm::cast<OperationStmt>(this)->getNumSuccessors(); -} - -unsigned Operation::getNumSuccessorOperands(unsigned index) const { - assert(isTerminator() && "Only terminators have successors."); - return llvm::cast<OperationStmt>(this)->getNumSuccessorOperands(index); -} -BasicBlock *Operation::getSuccessor(unsigned index) { - assert(isTerminator() && "Only terminators have successors"); - return llvm::cast<OperationStmt>(this)->getSuccessor(index); -} -void Operation::setSuccessor(BasicBlock *block, unsigned index) { - assert(isTerminator() && "Only terminators have successors"); - llvm::cast<OperationStmt>(this)->setSuccessor(block, index); -} - -void Operation::eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) { - assert(isTerminator() && "Only terminators have successors"); - return llvm::cast<OperationStmt>(this)->eraseSuccessorOperand(succIndex, - opIndex); -} -auto Operation::getSuccessorOperands(unsigned index) const - -> llvm::iterator_range<const_operand_iterator> { - assert(isTerminator() && "Only terminators have successors."); - unsigned succOperandIndex = - llvm::cast<OperationStmt>(this)->getSuccessorOperandIndex(index); - return {const_operand_iterator(this, succOperandIndex), - const_operand_iterator(this, succOperandIndex + - getNumSuccessorOperands(index))}; -} -auto Operation::getSuccessorOperands(unsigned index) - -> llvm::iterator_range<operand_iterator> { - assert(isTerminator() && "Only terminators have successors."); - unsigned succOperandIndex = - llvm::cast<OperationStmt>(this)->getSuccessorOperandIndex(index); - return {operand_iterator(this, succOperandIndex), - operand_iterator(this, - succOperandIndex + getNumSuccessorOperands(index))}; -} - -/// Return true if there are no users of any results of this operation. -bool Operation::use_empty() const { - for (auto *result : getResults()) - if (!result->use_empty()) - return false; - return true; -} - -ArrayRef<NamedAttribute> Operation::getAttrs() const { - if (!attrs) - return {}; - return attrs->getElements(); -} - -/// If an attribute exists with the specified name, change it to the new -/// value. Otherwise, add a new attribute with the specified name/value. -void Operation::setAttr(Identifier name, Attribute value) { - assert(value && "attributes may never be null"); - auto origAttrs = getAttrs(); - - SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end()); - auto *context = getContext(); - - // If we already have this attribute, replace it. - for (auto &elt : newAttrs) - if (elt.first == name) { - elt.second = value; - attrs = AttributeListStorage::get(newAttrs, context); - return; - } - - // Otherwise, add it. - newAttrs.push_back({name, value}); - attrs = AttributeListStorage::get(newAttrs, context); -} - -/// Remove the attribute with the specified name if it exists. The return -/// value indicates whether the attribute was present or not. -auto Operation::removeAttr(Identifier name) -> RemoveResult { - auto origAttrs = getAttrs(); - for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { - if (origAttrs[i].first == name) { - SmallVector<NamedAttribute, 8> newAttrs; - newAttrs.reserve(origAttrs.size() - 1); - newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); - newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); - attrs = AttributeListStorage::get(newAttrs, getContext()); - return RemoveResult::Removed; - } - } - return RemoveResult::NotFound; -} - -/// Emit a note about this operation, reporting up to any diagnostic -/// handlers that may be listening. -void Operation::emitNote(const Twine &message) const { - getContext()->emitDiagnostic(getLoc(), message, - MLIRContext::DiagnosticKind::Note); -} - -/// Emit a warning about this operation, reporting up to any diagnostic -/// handlers that may be listening. -void Operation::emitWarning(const Twine &message) const { - getContext()->emitDiagnostic(getLoc(), message, - MLIRContext::DiagnosticKind::Warning); -} - -/// Emit an error about fatal conditions with this operation, reporting up to -/// any diagnostic handlers that may be listening. This function always returns -/// true. NOTE: This may terminate the containing application, only use when -/// the IR is in an inconsistent state. -bool Operation::emitError(const Twine &message) const { - return getContext()->emitError(getLoc(), message); -} - -/// Emit an error with the op name prefixed, like "'dim' op " which is -/// convenient for verifiers. -bool Operation::emitOpError(const Twine &message) const { - return emitError(Twine('\'') + getName().getStringRef() + "' op " + message); -} - -/// Remove this operation from its parent block and delete it. -void Operation::erase() { - return llvm::cast<OperationStmt>(this)->erase(); -} - -/// Attempt to constant fold this operation with the specified constant -/// operand values. If successful, this returns false and fills in the -/// results vector. If not, this returns true and results is unspecified. -bool Operation::constantFold(ArrayRef<Attribute> operands, - SmallVectorImpl<Attribute> &results) const { - if (auto *abstractOp = getAbstractOperation()) { - // If we have a registered operation definition matching this one, use it to - // try to constant fold the operation. - if (!abstractOp->constantFoldHook(this, operands, results)) - return false; - - // Otherwise, fall back on the dialect hook to handle it. - return abstractOp->dialect.constantFoldHook(this, operands, results); - } - - // If this operation hasn't been registered or doesn't have abstract - // operation, fall back to a dialect which matches the prefix. - auto opName = getName().getStringRef(); - if (auto *dialect = getContext()->getRegisteredDialect(opName)) { - return dialect->constantFoldHook(this, operands, results); - } - - return true; -} - -/// Methods for support type inquiry through isa, cast, and dyn_cast. -bool Operation::classof(const Statement *stmt) { - return stmt->getKind() == Statement::Kind::Operation; -} -bool Operation::classof(const IROperandOwner *ptr) { - return ptr->getKind() == IROperandOwner::Kind::OperationStmt; -} - -//===----------------------------------------------------------------------===// // OpState trait class. //===----------------------------------------------------------------------===// @@ -290,19 +94,20 @@ void OpState::emitNote(const Twine &message) const { // Op Trait implementations //===----------------------------------------------------------------------===// -bool OpTrait::impl::verifyZeroOperands(const Operation *op) { +bool OpTrait::impl::verifyZeroOperands(const OperationInst *op) { if (op->getNumOperands() != 0) return op->emitOpError("requires zero operands"); return false; } -bool OpTrait::impl::verifyOneOperand(const Operation *op) { +bool OpTrait::impl::verifyOneOperand(const OperationInst *op) { if (op->getNumOperands() != 1) return op->emitOpError("requires a single operand"); return false; } -bool OpTrait::impl::verifyNOperands(const Operation *op, unsigned numOperands) { +bool OpTrait::impl::verifyNOperands(const OperationInst *op, + unsigned numOperands) { if (op->getNumOperands() != numOperands) { return op->emitOpError("expected " + Twine(numOperands) + " operands, but found " + @@ -311,7 +116,7 @@ bool OpTrait::impl::verifyNOperands(const Operation *op, unsigned numOperands) { return false; } -bool OpTrait::impl::verifyAtLeastNOperands(const Operation *op, +bool OpTrait::impl::verifyAtLeastNOperands(const OperationInst *op, unsigned numOperands) { if (op->getNumOperands() < numOperands) return op->emitOpError("expected " + Twine(numOperands) + @@ -331,7 +136,7 @@ static Type getTensorOrVectorElementType(Type type) { return type; } -bool OpTrait::impl::verifyOperandsAreIntegerLike(const Operation *op) { +bool OpTrait::impl::verifyOperandsAreIntegerLike(const OperationInst *op) { for (auto *operand : op->getOperands()) { auto type = getTensorOrVectorElementType(operand->getType()); if (!type.isIntOrIndex()) @@ -340,7 +145,7 @@ bool OpTrait::impl::verifyOperandsAreIntegerLike(const Operation *op) { return false; } -bool OpTrait::impl::verifySameTypeOperands(const Operation *op) { +bool OpTrait::impl::verifySameTypeOperands(const OperationInst *op) { // Zero or one operand always have the "same" type. unsigned nOperands = op->getNumOperands(); if (nOperands < 2) @@ -354,25 +159,26 @@ bool OpTrait::impl::verifySameTypeOperands(const Operation *op) { return false; } -bool OpTrait::impl::verifyZeroResult(const Operation *op) { +bool OpTrait::impl::verifyZeroResult(const OperationInst *op) { if (op->getNumResults() != 0) return op->emitOpError("requires zero results"); return false; } -bool OpTrait::impl::verifyOneResult(const Operation *op) { +bool OpTrait::impl::verifyOneResult(const OperationInst *op) { if (op->getNumResults() != 1) return op->emitOpError("requires one result"); return false; } -bool OpTrait::impl::verifyNResults(const Operation *op, unsigned numOperands) { +bool OpTrait::impl::verifyNResults(const OperationInst *op, + unsigned numOperands) { if (op->getNumResults() != numOperands) return op->emitOpError("expected " + Twine(numOperands) + " results"); return false; } -bool OpTrait::impl::verifyAtLeastNResults(const Operation *op, +bool OpTrait::impl::verifyAtLeastNResults(const OperationInst *op, unsigned numOperands) { if (op->getNumResults() < numOperands) return op->emitOpError("expected " + Twine(numOperands) + @@ -401,7 +207,7 @@ static bool verifyShapeMatch(Type type1, Type type2) { return false; } -bool OpTrait::impl::verifySameOperandsAndResultShape(const Operation *op) { +bool OpTrait::impl::verifySameOperandsAndResultShape(const OperationInst *op) { if (op->getNumOperands() == 0 || op->getNumResults() == 0) return true; @@ -419,7 +225,7 @@ bool OpTrait::impl::verifySameOperandsAndResultShape(const Operation *op) { return false; } -bool OpTrait::impl::verifySameOperandsAndResultType(const Operation *op) { +bool OpTrait::impl::verifySameOperandsAndResultType(const OperationInst *op) { if (op->getNumOperands() == 0 || op->getNumResults() == 0) return true; @@ -438,8 +244,8 @@ bool OpTrait::impl::verifySameOperandsAndResultType(const Operation *op) { } static bool verifyBBArguments( - llvm::iterator_range<Operation::const_operand_iterator> operands, - const BasicBlock *destBB, const Operation *op) { + llvm::iterator_range<OperationInst::const_operand_iterator> operands, + const BasicBlock *destBB, const OperationInst *op) { unsigned operandCount = std::distance(operands.begin(), operands.end()); if (operandCount != destBB->getNumArguments()) return op->emitError("branch has " + Twine(operandCount) + @@ -455,9 +261,9 @@ static bool verifyBBArguments( return false; } -static bool verifyTerminatorSuccessors(const Operation *op) { +static bool verifyTerminatorSuccessors(const OperationInst *op) { // Verify that the operands lines up with the BB arguments in the successor. - const Function *fn = op->getOperationFunction(); + const Function *fn = op->getFunction(); for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { auto *succ = op->getSuccessor(i); if (succ->getFunction() != fn) @@ -468,17 +274,15 @@ static bool verifyTerminatorSuccessors(const Operation *op) { return false; } -bool OpTrait::impl::verifyIsTerminator(const Operation *op) { +bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) { // Verify that the operation is at the end of the respective parent block. - if (op->getOperationFunction()->isML()) { - auto *stmt = cast<OperationStmt>(op); - StmtBlock *block = stmt->getBlock(); - if (!block || block->getContainingStmt() || &block->back() != stmt) + if (op->getFunction()->isML()) { + StmtBlock *block = op->getBlock(); + if (!block || block->getContainingStmt() || &block->back() != op) return op->emitOpError("must be the last statement in the ML function"); } else { - auto *inst = cast<OperationInst>(op); - const BasicBlock *block = inst->getBlock(); - if (!block || &block->back() != inst) + const BasicBlock *block = op->getBlock(); + if (!block || &block->back() != op) return op->emitOpError( "must be the last instruction in the parent basic block."); } @@ -489,7 +293,7 @@ bool OpTrait::impl::verifyIsTerminator(const Operation *op) { return false; } -bool OpTrait::impl::verifyResultsAreBoolLike(const Operation *op) { +bool OpTrait::impl::verifyResultsAreBoolLike(const OperationInst *op) { for (auto *result : op->getResults()) { auto elementType = getTensorOrVectorElementType(result->getType()); auto intType = elementType.dyn_cast<IntegerType>(); @@ -501,7 +305,7 @@ bool OpTrait::impl::verifyResultsAreBoolLike(const Operation *op) { return false; } -bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) { +bool OpTrait::impl::verifyResultsAreFloatLike(const OperationInst *op) { for (auto *result : op->getResults()) { if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>()) return op->emitOpError("requires a floating point type"); @@ -510,7 +314,7 @@ bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) { return false; } -bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) { +bool OpTrait::impl::verifyResultsAreIntegerLike(const OperationInst *op) { for (auto *result : op->getResults()) { auto type = getTensorOrVectorElementType(result->getType()); if (!type.isIntOrIndex()) @@ -543,7 +347,7 @@ bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) { parser->addTypeToList(type, result->types); } -void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) { +void impl::printBinaryOp(const OperationInst *op, OpAsmPrinter *p) { *p << op->getName() << ' ' << *op->getOperand(0) << ", " << *op->getOperand(1); p->printOptionalAttrDict(op->getAttrs()); @@ -569,7 +373,7 @@ bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) { parser->addTypeToList(dstType, result->types); } -void impl::printCastOp(const Operation *op, OpAsmPrinter *p) { +void impl::printCastOp(const OperationInst *op, OpAsmPrinter *p) { *p << op->getName() << ' ' << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType(); } diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 9e4d8bb180c..8c41d488a8b 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -58,12 +58,14 @@ void Pattern::anchor() {} // RewritePattern and PatternRewriter implementation //===----------------------------------------------------------------------===// -void RewritePattern::rewrite(Operation *op, std::unique_ptr<PatternState> state, +void RewritePattern::rewrite(OperationInst *op, + std::unique_ptr<PatternState> state, PatternRewriter &rewriter) const { rewrite(op, rewriter); } -void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const { +void RewritePattern::rewrite(OperationInst *op, + PatternRewriter &rewriter) const { llvm_unreachable("need to implement one of the rewrite functions!"); } @@ -77,7 +79,7 @@ PatternRewriter::~PatternRewriter() { /// clients can specify a list of other nodes that this replacement may make /// (perhaps transitively) dead. If any of those ops are dead, this will /// remove them as well. -void PatternRewriter::replaceOp(Operation *op, ArrayRef<Value *> newValues, +void PatternRewriter::replaceOp(OperationInst *op, ArrayRef<Value *> newValues, ArrayRef<Value *> valuesToRemoveIfDead) { // Notify the rewriter subclass that we're about to replace this root. notifyRootReplaced(op); @@ -97,7 +99,8 @@ void PatternRewriter::replaceOp(Operation *op, ArrayRef<Value *> newValues, /// op and newOp are known to have the same number of results, replace the /// uses of op with uses of newOp void PatternRewriter::replaceOpWithResultsOfAnotherOp( - Operation *op, Operation *newOp, ArrayRef<Value *> valuesToRemoveIfDead) { + OperationInst *op, OperationInst *newOp, + ArrayRef<Value *> valuesToRemoveIfDead) { assert(op->getNumResults() == newOp->getNumResults() && "replacement op doesn't match results of original op"); if (op->getNumResults() == 1) @@ -117,7 +120,7 @@ void PatternRewriter::replaceOpWithResultsOfAnotherOp( /// should remove if they are dead at this point. /// void PatternRewriter::updatedRootInPlace( - Operation *op, ArrayRef<Value *> valuesToRemoveIfDead) { + OperationInst *op, ArrayRef<Value *> valuesToRemoveIfDead) { // Notify the rewriter subclass that we're about to replace this root. notifyRootUpdated(op); @@ -132,7 +135,7 @@ void PatternRewriter::updatedRootInPlace( /// Find the highest benefit pattern available in the pattern set for the DAG /// rooted at the specified node. This returns the pattern if found, or null /// if there are no matches. -auto PatternMatcher::findMatch(Operation *op) -> MatchResult { +auto PatternMatcher::findMatch(OperationInst *op) -> MatchResult { // TODO: This is a completely trivial implementation, expand this in the // future. diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index 8bff23d41ed..19457efa8c3 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -15,6 +15,7 @@ // limitations under the License. // ============================================================================= +#include "AttributeListStorage.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinOps.h" @@ -65,8 +66,8 @@ Statement::~Statement() { /// Destroy this statement or one of its subclasses. void Statement::destroy() { switch (this->getKind()) { - case Kind::Operation: - cast<OperationStmt>(this)->destroy(); + case Kind::OperationInst: + cast<OperationInst>(this)->destroy(); break; case Kind::For: delete cast<ForStmt>(this); @@ -95,7 +96,7 @@ const Value *Statement::getOperand(unsigned idx) const { // it is an induction variable, or it is a result of affine apply operation // with dimension id arguments. bool Value::isValidDim() const { - if (auto *stmt = getDefiningStmt()) { + if (auto *stmt = getDefiningInst()) { // Top level statement or constant operation is ok. if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>()) return true; @@ -113,7 +114,7 @@ bool Value::isValidDim() const { // the top level, or it is a result of affine apply operation with symbol // arguments. bool Value::isValidSymbol() const { - if (auto *stmt = getDefiningStmt()) { + if (auto *stmt = getDefiningInst()) { // Top level statement or constant operation is ok. if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>()) return true; @@ -133,8 +134,8 @@ void Statement::setOperand(unsigned idx, Value *value) { unsigned Statement::getNumOperands() const { switch (getKind()) { - case Kind::Operation: - return cast<OperationStmt>(this)->getNumOperands(); + case Kind::OperationInst: + return cast<OperationInst>(this)->getNumOperands(); case Kind::For: return cast<ForStmt>(this)->getNumOperands(); case Kind::If: @@ -144,8 +145,8 @@ unsigned Statement::getNumOperands() const { MutableArrayRef<StmtOperand> Statement::getStmtOperands() { switch (getKind()) { - case Kind::Operation: - return cast<OperationStmt>(this)->getStmtOperands(); + case Kind::OperationInst: + return cast<OperationInst>(this)->getStmtOperands(); case Kind::For: return cast<ForStmt>(this)->getStmtOperands(); case Kind::If: @@ -177,7 +178,7 @@ bool Statement::emitError(const Twine &message) const { // Returns whether the Statement is a terminator. bool Statement::isTerminator() const { - if (auto *op = dyn_cast<OperationStmt>(this)) + if (auto *op = dyn_cast<OperationInst>(this)) return op->isTerminator(); return false; } @@ -264,11 +265,11 @@ void Statement::dropAllReferences() { } //===----------------------------------------------------------------------===// -// OperationStmt +// OperationInst //===----------------------------------------------------------------------===// -/// Create a new OperationStmt with the specific fields. -OperationStmt *OperationStmt::create(Location location, OperationName name, +/// Create a new OperationInst with the specific fields. +OperationInst *OperationInst::create(Location location, OperationName name, ArrayRef<Value *> operands, ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes, @@ -285,9 +286,9 @@ OperationStmt *OperationStmt::create(Location location, OperationName name, resultTypes.size(), numSuccessors, numSuccessors, numOperands); void *rawMem = malloc(byteSize); - // Initialize the OperationStmt part of the statement. + // Initialize the OperationInst part of the statement. auto stmt = ::new (rawMem) - OperationStmt(location, name, numOperands, resultTypes.size(), + OperationInst(location, name, numOperands, resultTypes.size(), numSuccessors, attributes, context); // Initialize the results and operands. @@ -355,15 +356,22 @@ OperationStmt *OperationStmt::create(Location location, OperationName name, return stmt; } -OperationStmt::OperationStmt(Location location, OperationName name, +OperationInst::OperationInst(Location location, OperationName name, unsigned numOperands, unsigned numResults, unsigned numSuccessors, ArrayRef<NamedAttribute> attributes, MLIRContext *context) - : Operation(name, attributes, location, context), numOperands(numOperands), - numResults(numResults), numSuccs(numSuccessors) {} + : Statement(Kind::OperationInst, location), numOperands(numOperands), + numResults(numResults), numSuccs(numSuccessors), name(name) { +#ifndef NDEBUG + for (auto elt : attributes) + assert(elt.second != nullptr && "Attributes cannot have null entries"); +#endif -OperationStmt::~OperationStmt() { + this->attrs = AttributeListStorage::get(attributes, context); +} + +OperationInst::~OperationInst() { // Explicitly run the destructors for the operands and results. for (auto &operand : getStmtOperands()) operand.~StmtOperand(); @@ -377,13 +385,27 @@ OperationStmt::~OperationStmt() { successor.~StmtBlockOperand(); } -void OperationStmt::destroy() { - this->~OperationStmt(); +/// Return true if there are no users of any results of this operation. +bool OperationInst::use_empty() const { + for (auto *result : getResults()) + if (!result->use_empty()) + return false; + return true; +} + +ArrayRef<NamedAttribute> OperationInst::getAttrs() const { + if (!attrs) + return {}; + return attrs->getElements(); +} + +void OperationInst::destroy() { + this->~OperationInst(); free(this); } /// Return the context this operation is associated with. -MLIRContext *OperationStmt::getContext() const { +MLIRContext *OperationInst::getContext() const { // If we have a result or operand type, that is a constant time way to get // to the context. if (getNumResults()) @@ -396,9 +418,9 @@ MLIRContext *OperationStmt::getContext() const { return getFunction()->getContext(); } -bool OperationStmt::isReturn() const { return isa<ReturnOp>(); } +bool OperationInst::isReturn() const { return isa<ReturnOp>(); } -void OperationStmt::setSuccessor(BasicBlock *block, unsigned index) { +void OperationInst::setSuccessor(BasicBlock *block, unsigned index) { assert(index < getNumSuccessors()); getBlockOperands()[index].set(block); } @@ -413,6 +435,96 @@ void OperationInst::eraseOperand(unsigned index) { Operands[getNumOperands()].~StmtOperand(); } +auto OperationInst::getSuccessorOperands(unsigned index) const + -> llvm::iterator_range<const_operand_iterator> { + assert(isTerminator() && "Only terminators have successors."); + unsigned succOperandIndex = getSuccessorOperandIndex(index); + return {const_operand_iterator(this, succOperandIndex), + const_operand_iterator(this, succOperandIndex + + getNumSuccessorOperands(index))}; +} +auto OperationInst::getSuccessorOperands(unsigned index) + -> llvm::iterator_range<operand_iterator> { + assert(isTerminator() && "Only terminators have successors."); + unsigned succOperandIndex = getSuccessorOperandIndex(index); + return {operand_iterator(this, succOperandIndex), + operand_iterator(this, + succOperandIndex + getNumSuccessorOperands(index))}; +} + +/// If an attribute exists with the specified name, change it to the new +/// value. Otherwise, add a new attribute with the specified name/value. +void OperationInst::setAttr(Identifier name, Attribute value) { + assert(value && "attributes may never be null"); + auto origAttrs = getAttrs(); + + SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end()); + auto *context = getContext(); + + // If we already have this attribute, replace it. + for (auto &elt : newAttrs) + if (elt.first == name) { + elt.second = value; + attrs = AttributeListStorage::get(newAttrs, context); + return; + } + + // Otherwise, add it. + newAttrs.push_back({name, value}); + attrs = AttributeListStorage::get(newAttrs, context); +} + +/// Remove the attribute with the specified name if it exists. The return +/// value indicates whether the attribute was present or not. +auto OperationInst::removeAttr(Identifier name) -> RemoveResult { + auto origAttrs = getAttrs(); + for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { + if (origAttrs[i].first == name) { + SmallVector<NamedAttribute, 8> newAttrs; + newAttrs.reserve(origAttrs.size() - 1); + newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); + newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); + attrs = AttributeListStorage::get(newAttrs, getContext()); + return RemoveResult::Removed; + } + } + return RemoveResult::NotFound; +} + +/// Attempt to constant fold this operation with the specified constant +/// operand values. If successful, this returns false and fills in the +/// results vector. If not, this returns true and results is unspecified. +bool OperationInst::constantFold(ArrayRef<Attribute> operands, + SmallVectorImpl<Attribute> &results) const { + if (auto *abstractOp = getAbstractOperation()) { + // If we have a registered operation definition matching this one, use it to + // try to constant fold the operation. + if (!abstractOp->constantFoldHook(llvm::cast<OperationInst>(this), operands, + results)) + return false; + + // Otherwise, fall back on the dialect hook to handle it. + return abstractOp->dialect.constantFoldHook(llvm::cast<OperationInst>(this), + operands, results); + } + + // If this operation hasn't been registered or doesn't have abstract + // operation, fall back to a dialect which matches the prefix. + auto opName = getName().getStringRef(); + if (auto *dialect = getContext()->getRegisteredDialect(opName)) { + return dialect->constantFoldHook(llvm::cast<OperationInst>(this), operands, + results); + } + + return true; +} + +/// Emit an error with the op name prefixed, like "'dim' op " which is +/// convenient for verifiers. +bool OperationInst::emitOpError(const Twine &message) const { + return emitError(Twine('\'') + getName().getStringRef() + "' op " + message); +} + //===----------------------------------------------------------------------===// // ForStmt //===----------------------------------------------------------------------===// @@ -625,7 +737,7 @@ Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap, SmallVector<Value *, 8> operands; SmallVector<StmtBlock *, 2> successors; - if (auto *opStmt = dyn_cast<OperationStmt>(this)) { + if (auto *opStmt = dyn_cast<OperationInst>(this)) { operands.reserve(getNumOperands() + opStmt->getNumSuccessors()); if (!opStmt->isTerminator()) { @@ -653,8 +765,8 @@ Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap, operands.push_back(nullptr); // Remap the successors operands. - for (auto &operand : opStmt->getSuccessorOperands(succ)) - operands.push_back(remapOperand(operand.get())); + for (auto *operand : opStmt->getSuccessorOperands(succ)) + operands.push_back(remapOperand(operand)); } } @@ -662,7 +774,7 @@ Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap, resultTypes.reserve(opStmt->getNumResults()); for (auto *result : opStmt->getResults()) resultTypes.push_back(result->getType()); - auto *newOp = OperationStmt::create(getLoc(), opStmt->getName(), operands, + auto *newOp = OperationInst::create(getLoc(), opStmt->getName(), operands, resultTypes, opStmt->getAttrs(), successors, context); // Remember the mapping of any results. diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp index a50861a3060..cfb09e6bf45 100644 --- a/mlir/lib/IR/StmtBlock.cpp +++ b/mlir/lib/IR/StmtBlock.cpp @@ -100,13 +100,13 @@ void StmtBlock::eraseArgument(unsigned index) { // Terminator management //===----------------------------------------------------------------------===// -OperationStmt *StmtBlock::getTerminator() { +OperationInst *StmtBlock::getTerminator() { if (empty()) return nullptr; // Check if the last instruction is a terminator. auto &backInst = statements.back(); - auto *opStmt = dyn_cast<OperationStmt>(&backInst); + auto *opStmt = dyn_cast<OperationInst>(&backInst); if (!opStmt || !opStmt->isTerminator()) return nullptr; return opStmt; diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index db58e126e61..41a6d80e2a2 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -28,29 +28,13 @@ OperationInst *Value::getDefiningInst() { return nullptr; } -/// If this value is the result of an OperationStmt, return the statement -/// that defines it. -OperationStmt *Value::getDefiningStmt() { - if (auto *result = dyn_cast<StmtResult>(this)) - return result->getOwner(); - return nullptr; -} - -Operation *Value::getDefiningOperation() { - if (auto *inst = getDefiningInst()) - return inst; - if (auto *stmt = getDefiningStmt()) - return stmt; - return nullptr; -} - -/// Return the function that this Valueis defined in. +/// Return the function that this Value is defined in. Function *Value::getFunction() { switch (getKind()) { case Value::Kind::BlockArgument: return cast<BlockArgument>(this)->getFunction(); case Value::Kind::StmtResult: - return getDefiningStmt()->getFunction(); + return getDefiningInst()->getFunction(); case Value::Kind::ForStmt: return cast<ForStmt>(this)->getFunction(); } @@ -73,8 +57,8 @@ void IRObjectWithUseList::replaceAllUsesWith(IRObjectWithUseList *newValue) { /// Return the context this operation is associated with. MLIRContext *IROperandOwner::getContext() const { switch (getKind()) { - case Kind::OperationStmt: - return cast<OperationStmt>(this)->getContext(); + case Kind::OperationInst: + return cast<OperationInst>(this)->getContext(); case Kind::ForStmt: return cast<ForStmt>(this)->getContext(); case Kind::IfStmt: diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 06495eb81ab..35891f5784b 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -103,7 +103,7 @@ private: namespace { using CreateOperationFunction = - std::function<Operation *(const OperationState &)>; + std::function<OperationInst *(const OperationState &)>; /// This class implement support for parsing global entities like types and /// shared entities like SSA names. It is intended to be subclassed by @@ -1915,8 +1915,10 @@ public: // Operations ParseResult parseOperation(const CreateOperationFunction &createOpFunc); - Operation *parseVerboseOperation(const CreateOperationFunction &createOpFunc); - Operation *parseCustomOperation(const CreateOperationFunction &createOpFunc); + OperationInst * + parseVerboseOperation(const CreateOperationFunction &createOpFunc); + OperationInst * + parseCustomOperation(const CreateOperationFunction &createOpFunc); /// Parse a single operation successor and it's operand list. virtual bool parseSuccessorAndUseList(BasicBlock *&dest, @@ -2184,7 +2186,7 @@ FunctionParser::parseOperation(const CreateOperationFunction &createOpFunc) { return ParseFailure; } - Operation *op; + OperationInst *op; if (getToken().is(Token::bare_identifier) || getToken().isKeyword()) op = parseCustomOperation(createOpFunc); else if (getToken().is(Token::string)) @@ -2220,7 +2222,7 @@ FunctionParser::parseOperation(const CreateOperationFunction &createOpFunc) { return ParseSuccess; } -Operation *FunctionParser::parseVerboseOperation( +OperationInst *FunctionParser::parseVerboseOperation( const CreateOperationFunction &createOpFunc) { // Get location information for the operation. @@ -2516,7 +2518,7 @@ private: }; } // end anonymous namespace. -Operation *FunctionParser::parseCustomOperation( +OperationInst *FunctionParser::parseCustomOperation( const CreateOperationFunction &createOpFunc) { auto opLoc = getToken().getLoc(); auto opName = getTokenSpelling(); @@ -2746,7 +2748,7 @@ ParseResult CFGFunctionParser::parseBasicBlock() { // into. builder.setInsertionPointToEnd(block); - auto createOpFunc = [&](const OperationState &result) -> Operation * { + auto createOpFunc = [&](const OperationState &result) -> OperationInst * { return builder.createOperation(result); }; @@ -3149,7 +3151,7 @@ ParseResult MLFunctionParser::parseElseClause(StmtBlock *elseClause) { /// Parse a list of statements ending with `return` or `}` /// ParseResult MLFunctionParser::parseStatements(StmtBlock *block) { - auto createOpFunc = [&](const OperationState &state) -> Operation * { + auto createOpFunc = [&](const OperationState &state) -> OperationInst * { return builder.createOperation(state); }; diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index 7611c6e741b..19a4c8d1afe 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -56,7 +56,7 @@ struct MemRefCastFolder : public RewritePattern { MemRefCastFolder(StringRef rootOpName, MLIRContext *context) : RewritePattern(rootOpName, 1, context) {} - PatternMatchResult match(Operation *op) const override { + PatternMatchResult match(OperationInst *op) const override { for (auto *operand : op->getOperands()) if (matchPattern(operand, m_Op<MemRefCastOp>())) return matchSuccess(); @@ -64,9 +64,9 @@ struct MemRefCastFolder : public RewritePattern { return matchFailure(); } - void rewrite(Operation *op, PatternRewriter &rewriter) const override { + void rewrite(OperationInst *op, PatternRewriter &rewriter) const override { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) - if (auto *memref = op->getOperand(i)->getDefiningOperation()) + if (auto *memref = op->getOperand(i)->getDefiningInst()) if (auto cast = memref->dyn_cast<MemRefCastOp>()) op->setOperand(i, cast->getOperand()); rewriter.updatedRootInPlace(op); @@ -122,7 +122,7 @@ struct SimplifyAddX0 : public RewritePattern { SimplifyAddX0(MLIRContext *context) : RewritePattern(AddIOp::getOperationName(), 1, context) {} - PatternMatchResult match(Operation *op) const override { + PatternMatchResult match(OperationInst *op) const override { auto addi = op->cast<AddIOp>(); if (matchPattern(addi->getOperand(1), m_Zero())) @@ -130,7 +130,7 @@ struct SimplifyAddX0 : public RewritePattern { return matchFailure(); } - void rewrite(Operation *op, PatternRewriter &rewriter) const override { + void rewrite(OperationInst *op, PatternRewriter &rewriter) const override { rewriter.replaceOp(op, op->getOperand(0)); } }; @@ -228,7 +228,7 @@ struct SimplifyAllocConst : public RewritePattern { SimplifyAllocConst(MLIRContext *context) : RewritePattern(AllocOp::getOperationName(), 1, context) {} - PatternMatchResult match(Operation *op) const override { + PatternMatchResult match(OperationInst *op) const override { auto alloc = op->cast<AllocOp>(); // Check to see if any dimensions operands are constants. If so, we can @@ -239,7 +239,7 @@ struct SimplifyAllocConst : public RewritePattern { return matchFailure(); } - void rewrite(Operation *op, PatternRewriter &rewriter) const override { + void rewrite(OperationInst *op, PatternRewriter &rewriter) const override { auto allocOp = op->cast<AllocOp>(); auto memrefType = allocOp->getType(); @@ -258,7 +258,7 @@ struct SimplifyAllocConst : public RewritePattern { newShapeConstants.push_back(dimSize); continue; } - auto *defOp = allocOp->getOperand(dynamicDimPos)->getDefiningOperation(); + auto *defOp = allocOp->getOperand(dynamicDimPos)->getDefiningInst(); OpPointer<ConstantIndexOp> constantIndexOp; if (defOp && (constantIndexOp = defOp->dyn_cast<ConstantIndexOp>())) { // Dynamic shape dimension will be folded. @@ -1105,7 +1105,7 @@ struct SimplifyMulX1 : public RewritePattern { SimplifyMulX1(MLIRContext *context) : RewritePattern(MulIOp::getOperationName(), 1, context) {} - PatternMatchResult match(Operation *op) const override { + PatternMatchResult match(OperationInst *op) const override { auto muli = op->cast<MulIOp>(); if (matchPattern(muli->getOperand(1), m_One())) @@ -1113,7 +1113,7 @@ struct SimplifyMulX1 : public RewritePattern { return matchFailure(); } - void rewrite(Operation *op, PatternRewriter &rewriter) const override { + void rewrite(OperationInst *op, PatternRewriter &rewriter) const override { rewriter.replaceOp(op, op->getOperand(0)); } }; @@ -1308,14 +1308,14 @@ struct SimplifyXMinusX : public RewritePattern { SimplifyXMinusX(MLIRContext *context) : RewritePattern(SubIOp::getOperationName(), 1, context) {} - PatternMatchResult match(Operation *op) const override { + PatternMatchResult match(OperationInst *op) const override { auto subi = op->cast<SubIOp>(); if (subi->getOperand(0) == subi->getOperand(1)) return matchSuccess(); return matchFailure(); } - void rewrite(Operation *op, PatternRewriter &rewriter) const override { + void rewrite(OperationInst *op, PatternRewriter &rewriter) const override { auto subi = op->cast<SubIOp>(); auto result = rewriter.create<ConstantIntOp>(op->getLoc(), 0, subi->getType()); diff --git a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp index 02b4c4674ab..e4243a6de25 100644 --- a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp +++ b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp @@ -86,14 +86,14 @@ void VectorTransferReadOp::build(Builder *builder, OperationState *result, result->addTypes(vectorType); } -llvm::iterator_range<Operation::operand_iterator> +llvm::iterator_range<OperationInst::operand_iterator> VectorTransferReadOp::getIndices() { auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; auto end = begin + getMemRefType().getRank(); return {begin, end}; } -llvm::iterator_range<Operation::const_operand_iterator> +llvm::iterator_range<OperationInst::const_operand_iterator> VectorTransferReadOp::getIndices() const { auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; auto end = begin + getMemRefType().getRank(); @@ -303,14 +303,14 @@ void VectorTransferWriteOp::build(Builder *builder, OperationState *result, builder->getAffineMapAttr(permutationMap)); } -llvm::iterator_range<Operation::operand_iterator> +llvm::iterator_range<OperationInst::operand_iterator> VectorTransferWriteOp::getIndices() { auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; auto end = begin + getMemRefType().getRank(); return {begin, end}; } -llvm::iterator_range<Operation::const_operand_iterator> +llvm::iterator_range<OperationInst::const_operand_iterator> VectorTransferWriteOp::getIndices() const { auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; auto end = begin + getMemRefType().getRank(); diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index a4d474dc24a..713aa0b1791 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -111,8 +111,8 @@ private: /// descriptor and get the pointer to the element indexed by the linearized /// subscript. Return nullptr on errors. llvm::Value *emitMemRefElementAccess( - const Value *memRef, const Operation &op, - llvm::iterator_range<Operation::const_operand_iterator> opIndices); + const Value *memRef, const OperationInst &op, + llvm::iterator_range<OperationInst::const_operand_iterator> opIndices); /// Emit LLVM IR corresponding to the given Alloc `op`. In particular, create /// a Value for the MemRef descriptor, store any dynamic sizes passed to @@ -307,7 +307,7 @@ ModuleLowerer::linearizeSubscripts(ArrayRef<llvm::Value *> indices, // the location of `op` and return true. Return false if the type is supported. // TODO(zinenko): this function should disappear when the conversion fully // supports MemRefs. -static bool checkSupportedMemRefType(MemRefType type, const Operation &op) { +static bool checkSupportedMemRefType(MemRefType type, const OperationInst &op) { if (!type.getAffineMaps().empty()) return op.emitError("NYI: memrefs with affine maps"); if (type.getMemorySpace() != 0) @@ -316,8 +316,8 @@ static bool checkSupportedMemRefType(MemRefType type, const Operation &op) { } llvm::Value *ModuleLowerer::emitMemRefElementAccess( - const Value *memRef, const Operation &op, - llvm::iterator_range<Operation::const_operand_iterator> opIndices) { + const Value *memRef, const OperationInst &op, + llvm::iterator_range<OperationInst::const_operand_iterator> opIndices) { auto type = memRef->getType().dyn_cast<MemRefType>(); assert(type && "expected memRef value to have a MemRef type"); if (checkSupportedMemRefType(type, op)) @@ -425,7 +425,7 @@ ModuleLowerer::emitMemRefDealloc(ConstOpPointer<DeallocOp> deallocOp) { // This forcibly recreates the APFloat with IEEESingle semantics to make sure // LLVM constructs a `float` constant. static llvm::ConstantFP *getFloatConstant(APFloat APvalue, - const Operation &inst, + const OperationInst &inst, llvm::LLVMContext *context) { bool unused; APFloat::opStatus status = APvalue.convert( diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 575ae2e1c9b..4b198589e2c 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -50,10 +50,10 @@ struct CSE : public FunctionPass { }; // TODO(riverriddle) Handle commutative operations. -struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> { - static unsigned getHashValue(const Operation *op) { +struct SimpleOperationInfo : public llvm::DenseMapInfo<OperationInst *> { + static unsigned getHashValue(const OperationInst *op) { // Hash the operations based upon their: - // - Operation Name + // - OperationInst Name // - Attributes // - Result Types // - Operands @@ -62,7 +62,7 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> { hash_combine_range(op->result_type_begin(), op->result_type_end()), hash_combine_range(op->operand_begin(), op->operand_end())); } - static bool isEqual(const Operation *lhs, const Operation *rhs) { + static bool isEqual(const OperationInst *lhs, const OperationInst *rhs) { if (lhs == rhs) return true; if (lhs == getTombstoneKey() || lhs == getEmptyKey() || @@ -93,8 +93,8 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> { struct CSEImpl { using AllocatorTy = llvm::RecyclingAllocator< llvm::BumpPtrAllocator, - llvm::ScopedHashTableVal<Operation *, Operation *>>; - using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *, + llvm::ScopedHashTableVal<OperationInst *, OperationInst *>>; + using ScopedMapTy = llvm::ScopedHashTable<OperationInst *, OperationInst *, SimpleOperationInfo, AllocatorTy>; /// Erase any operations that were marked as dead during simplification. @@ -104,7 +104,7 @@ struct CSEImpl { } /// Attempt to eliminate a redundant operation. - void simplifyOperation(Operation *op) { + void simplifyOperation(OperationInst *op) { // TODO(riverriddle) We currently only eliminate non side-effecting // operations. if (!op->hasNoSideEffect()) @@ -141,7 +141,7 @@ struct CSEImpl { ScopedMapTy knownValues; /// Operations marked as dead and to be erased. - std::vector<Operation *> opsToErase; + std::vector<OperationInst *> opsToErase; }; /// Common sub-expression elimination for CFG functions. @@ -224,7 +224,7 @@ struct MLCSE : public CSEImpl, StmtWalker<MLCSE> { StmtWalker<MLCSE>::walk(Start, End); } - void visitOperationStmt(OperationStmt *stmt) { simplifyOperation(stmt); } + void visitOperationInst(OperationInst *stmt) { simplifyOperation(stmt); } }; } // end anonymous namespace diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index 84507b91703..365533561f9 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -42,12 +42,12 @@ namespace { // with no remaining uses are collected and erased after the walk. // TODO(andydavis) Remove this when Chris adds instruction combiner pass. struct ComposeAffineMaps : public FunctionPass, StmtWalker<ComposeAffineMaps> { - std::vector<OperationStmt *> affineApplyOpsToErase; + std::vector<OperationInst *> affineApplyOpsToErase; explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {} using StmtListType = llvm::iplist<Statement>; void walk(StmtListType::iterator Start, StmtListType::iterator End); - void visitOperationStmt(OperationStmt *stmt); + void visitOperationInst(OperationInst *stmt); PassResult runOnMLFunction(MLFunction *f) override; using StmtWalker<ComposeAffineMaps>::walk; @@ -72,7 +72,7 @@ void ComposeAffineMaps::walk(StmtListType::iterator Start, } } -void ComposeAffineMaps::visitOperationStmt(OperationStmt *opStmt) { +void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) { if (auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>()) { forwardSubstitute(affineApplyOp); bool allUsesEmpty = true; diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index b6b1dec7b17..a83e625c240 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -31,13 +31,14 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> { // All constants in the function post folding. SmallVector<Value *, 8> existingConstants; - // Operation statements that were folded and that need to be erased. - std::vector<OperationStmt *> opStmtsToErase; + // Operations that were folded and that need to be erased. + std::vector<OperationInst *> opStmtsToErase; using ConstantFactoryType = std::function<Value *(Attribute, Type)>; - bool foldOperation(Operation *op, SmallVectorImpl<Value *> &existingConstants, + bool foldOperation(OperationInst *op, + SmallVectorImpl<Value *> &existingConstants, ConstantFactoryType constantFactory); - void visitOperationStmt(OperationStmt *stmt); + void visitOperationInst(OperationInst *stmt); void visitForStmt(ForStmt *stmt); PassResult runOnCFGFunction(CFGFunction *f) override; PassResult runOnMLFunction(MLFunction *f) override; @@ -52,7 +53,7 @@ char ConstantFold::passID = 0; /// constants are found, we keep track of them in the existingConstants list. /// /// This returns false if the operation was successfully folded. -bool ConstantFold::foldOperation(Operation *op, +bool ConstantFold::foldOperation(OperationInst *op, SmallVectorImpl<Value *> &existingConstants, ConstantFactoryType constantFactory) { // If this operation is already a constant, just remember it for cleanup @@ -67,7 +68,7 @@ bool ConstantFold::foldOperation(Operation *op, SmallVector<Attribute, 8> operandConstants; for (auto *operand : op->getOperands()) { Attribute operandCst = nullptr; - if (auto *operandOp = operand->getDefiningOperation()) { + if (auto *operandOp = operand->getDefiningInst()) { if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>()) operandCst = operandConstantOp->getValue(); } @@ -138,8 +139,8 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) { return success(); } -// Override the walker's operation statement visit for constant folding. -void ConstantFold::visitOperationStmt(OperationStmt *stmt) { +// Override the walker's operation visiter for constant folding. +void ConstantFold::visitOperationInst(OperationInst *stmt) { auto constantFactory = [&](Attribute value, Type type) -> Value * { FuncBuilder builder(stmt); return builder.create<ConstantOp>(stmt->getLoc(), value, type); @@ -172,7 +173,7 @@ PassResult ConstantFold::runOnMLFunction(MLFunction *f) { // around dead constants. Check for them now and remove them. for (auto *cst : existingConstants) { if (cst->use_empty()) - cst->getDefiningStmt()->erase(); + cst->getDefiningInst()->erase(); } return success(); diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index fefe9f700c4..ca158a17e92 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -47,14 +47,14 @@ public: void visitForStmt(ForStmt *forStmt); void visitIfStmt(IfStmt *ifStmt); - void visitOperationStmt(OperationStmt *opStmt); + void visitOperationInst(OperationInst *opStmt); private: Value *getConstantIndexValue(int64_t value); void visitStmtBlock(StmtBlock *stmtBlock); Value *buildMinMaxReductionSeq( Location loc, CmpIPredicate predicate, - llvm::iterator_range<Operation::result_iterator> values); + llvm::iterator_range<OperationInst::result_iterator> values); CFGFunction *cfgFunc; FuncBuilder builder; @@ -64,7 +64,7 @@ private: }; } // end anonymous namespace -// Return a vector of OperationStmt's arguments as Values. For each +// Return a vector of OperationInst's arguments as Values. For each // statement operands, represented as Value, lookup its Value conterpart in // the valueRemapping table. static llvm::SmallVector<mlir::Value *, 4> @@ -84,7 +84,7 @@ operandsAs(Statement *opStmt, // remains the same but the values must be updated to be Values. Update the // mapping Value->Value as the conversion is performed. The operation // instruction is appended to current block (end of SESE region). -void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) { +void FunctionConverter::visitOperationInst(OperationInst *opStmt) { // Set up basic operation state (context, name, operands). OperationState state(cfgFunc->getContext(), opStmt->getLoc(), opStmt->getName()); @@ -136,7 +136,7 @@ void FunctionConverter::visitStmtBlock(StmtBlock *stmtBlock) { // recognize as a reduction by the subsequent passes. Value *FunctionConverter::buildMinMaxReductionSeq( Location loc, CmpIPredicate predicate, - llvm::iterator_range<Operation::result_iterator> values) { + llvm::iterator_range<OperationInst::result_iterator> values) { assert(!llvm::empty(values) && "empty min/max chain"); auto valueIt = values.begin(); @@ -600,7 +600,7 @@ void ModuleConverter::replaceReferences() { // operation "op" and containing an MLFunction-typed value with the result of // converting "func" to a CFGFunction. static inline void replaceMLFunctionAttr( - Operation &op, Identifier name, const Function *func, + OperationInst &op, Identifier name, const Function *func, const llvm::DenseMap<MLFunction *, CFGFunction *> &generatedFuncs) { if (!func->isML()) return; diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index cc2ca32421b..ed184dc9421 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -67,7 +67,7 @@ struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> { PassResult runOnMLFunction(MLFunction *f) override; void runOnForStmt(ForStmt *forStmt); - void visitOperationStmt(OperationStmt *opStmt); + void visitOperationInst(OperationInst *opStmt); bool generateDma(const MemRefRegion ®ion, ForStmt *forStmt, uint64_t *sizeInBytes); @@ -108,7 +108,7 @@ FunctionPass *mlir::createDmaGenerationPass(unsigned slowMemorySpace, // Gather regions to promote to buffers in faster memory space. // TODO(bondhugula): handle store op's; only load's handled for now. -void DmaGeneration::visitOperationStmt(OperationStmt *opStmt) { +void DmaGeneration::visitOperationInst(OperationInst *opStmt) { if (auto loadOp = opStmt->dyn_cast<LoadOp>()) { if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace) return; diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index c86eec3d276..67b36cfda30 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -80,7 +80,7 @@ char LoopFusion::passID = 0; FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; } -static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt, +static void getSingleMemRefAccess(OperationInst *loadOrStoreOpStmt, MemRefAccess *access) { if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) { access->memref = loadOp->getMemRef(); @@ -112,8 +112,8 @@ struct FusionCandidate { MemRefAccess dstAccess; }; -static FusionCandidate buildFusionCandidate(OperationStmt *srcStoreOpStmt, - OperationStmt *dstLoadOpStmt) { +static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpStmt, + OperationInst *dstLoadOpStmt) { FusionCandidate candidate; // Get store access for src loop nest. getSingleMemRefAccess(srcStoreOpStmt, &candidate.srcAccess); @@ -123,7 +123,7 @@ static FusionCandidate buildFusionCandidate(OperationStmt *srcStoreOpStmt, } // Returns the loop depth of the loop nest surrounding 'opStmt'. -static unsigned getLoopDepth(OperationStmt *opStmt) { +static unsigned getLoopDepth(OperationInst *opStmt) { unsigned loopDepth = 0; auto *currStmt = opStmt->getParentStmt(); ForStmt *currForStmt; @@ -141,15 +141,15 @@ namespace { class LoopNestStateCollector : public StmtWalker<LoopNestStateCollector> { public: SmallVector<ForStmt *, 4> forStmts; - SmallVector<OperationStmt *, 4> loadOpStmts; - SmallVector<OperationStmt *, 4> storeOpStmts; + SmallVector<OperationInst *, 4> loadOpStmts; + SmallVector<OperationInst *, 4> storeOpStmts; bool hasIfStmt = false; void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); } void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; } - void visitOperationStmt(OperationStmt *opStmt) { + void visitOperationInst(OperationInst *opStmt) { if (opStmt->isa<LoadOp>()) loadOpStmts.push_back(opStmt); if (opStmt->isa<StoreOp>()) @@ -171,10 +171,10 @@ public: unsigned id; // The top-level statment which is (or contains) loads/stores. Statement *stmt; - // List of load op stmts. - SmallVector<OperationStmt *, 4> loads; + // List of load operations. + SmallVector<OperationInst *, 4> loads; // List of store op stmts. - SmallVector<OperationStmt *, 4> stores; + SmallVector<OperationInst *, 4> stores; Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {} // Returns the load op count for 'memref'. @@ -312,8 +312,8 @@ public: } // Adds ops in 'loads' and 'stores' to node at 'id'. - void addToNode(unsigned id, const SmallVectorImpl<OperationStmt *> &loads, - const SmallVectorImpl<OperationStmt *> &stores) { + void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads, + const SmallVectorImpl<OperationInst *> &stores) { Node *node = getNode(id); for (auto *loadOpStmt : loads) node->loads.push_back(loadOpStmt); @@ -370,7 +370,7 @@ bool MemRefDependenceGraph::init(MLFunction *f) { } nodes.insert({node.id, node}); } - if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) { + if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) { if (auto loadOp = opStmt->dyn_cast<LoadOp>()) { // Create graph node for top-level load op. Node node(id++, &stmt); @@ -474,7 +474,7 @@ public: if (!isa<ForStmt>(dstNode->stmt)) continue; - SmallVector<OperationStmt *, 4> loads = dstNode->loads; + SmallVector<OperationInst *, 4> loads = dstNode->loads; while (!loads.empty()) { auto *dstLoadOpStmt = loads.pop_back_val(); auto *memref = dstLoadOpStmt->cast<LoadOp>()->getMemRef(); diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 183613a2f69..0a3dd65d1f4 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -120,7 +120,7 @@ PassResult LoopUnroll::runOnMLFunction(MLFunction *f) { return hasInnerLoops; } - bool visitOperationStmt(OperationStmt *opStmt) { return false; } + bool visitOperationInst(OperationInst *opStmt) { return false; } // FIXME: can't use base class method for this because that in turn would // need to use the derived class method above. CRTP doesn't allow it, and diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index fd23c341903..e2fd8b66e34 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -185,7 +185,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, // case of GPUs. llvm::SmallVector<Value *, 1> newResults = {}; if (std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value) { - b.setInsertionPoint(cast<OperationStmt>(transfer->getOperation())); + b.setInsertionPoint(cast<OperationInst>(transfer->getOperation())); auto *vector = b.create<LoadOp>(transfer->getLoc(), vecView->getResult(), ArrayRef<Value *>{state->zero}) ->getResult(); @@ -193,7 +193,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer, } // 6. Free the local buffer. - b.setInsertionPoint(cast<OperationStmt>(transfer->getOperation())); + b.setInsertionPoint(cast<OperationInst>(transfer->getOperation())); b.create<DeallocOp>(transfer->getLoc(), tmpScalarAlloc); // 7. It is now safe to erase the statement. @@ -207,13 +207,14 @@ public: explicit VectorTransferExpander(MLIRContext *context) : MLLoweringPattern(VectorTransferOpTy::getOperationName(), 1, context) {} - PatternMatchResult match(Operation *op) const override { + PatternMatchResult match(OperationInst *op) const override { if (m_Op<VectorTransferOpTy>().match(op)) return matchSuccess(); return matchFailure(); } - void rewriteOpStmt(Operation *op, MLFuncGlobalLoweringState *funcWiseState, + void rewriteOpStmt(OperationInst *op, + MLFuncGlobalLoweringState *funcWiseState, std::unique_ptr<PatternState> opState, MLFuncLoweringRewriter *rewriter) const override { rewriteAsLoops(&*op->dyn_cast<VectorTransferOpTy>(), rewriter, diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index b4d91b2506c..6f033710798 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -246,8 +246,8 @@ static SmallVector<unsigned, 8> delinearize(unsigned linearIndex, return res; } -static OperationStmt * -instantiate(FuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType, +static OperationInst * +instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType, DenseMap<const Value *, Value *> *substitutionsMap); /// Not all Values belong to a program slice scoped within the immediately @@ -263,7 +263,7 @@ static Value *substitute(Value *v, VectorType hwVectorType, DenseMap<const Value *, Value *> *substitutionsMap) { auto it = substitutionsMap->find(v); if (it == substitutionsMap->end()) { - auto *opStmt = cast<OperationStmt>(v->getDefiningOperation()); + auto *opStmt = cast<OperationInst>(v->getDefiningInst()); if (opStmt->isa<ConstantOp>()) { FuncBuilder b(opStmt); auto *inst = instantiate(&b, opStmt, hwVectorType, substitutionsMap); @@ -272,7 +272,7 @@ static Value *substitute(Value *v, VectorType hwVectorType, assert(res.second && "Insertion failed"); return res.first->second; } - v->getDefiningOperation()->emitError("Missing substitution"); + v->getDefiningInst()->emitError("Missing substitution"); return nullptr; } return it->second; @@ -384,7 +384,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType, /// - constant splat is replaced by constant splat of `hwVectorType`. /// TODO(ntv): add more substitutions on a per-need basis. static SmallVector<NamedAttribute, 1> -materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) { +materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) { SmallVector<NamedAttribute, 1> res; for (auto a : opStmt->getAttrs()) { if (auto splat = a.second.dyn_cast<SplatElementsAttr>()) { @@ -404,8 +404,8 @@ materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) { /// substitutionsMap. /// /// If the underlying substitution fails, this fails too and returns nullptr. -static OperationStmt * -instantiate(FuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType, +static OperationInst * +instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType, DenseMap<const Value *, Value *> *substitutionsMap) { assert(!opStmt->isa<VectorTransferReadOp>() && "Should call the function specialized for VectorTransferReadOp"); @@ -475,7 +475,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer, /// `hwVectorType` int the covering of the super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. -static OperationStmt * +static OperationInst * instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, DenseMap<const Value *, Value *> *substitutionsMap) { @@ -486,7 +486,7 @@ instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, auto cloned = b->create<VectorTransferReadOp>( read->getLoc(), hwVectorType, read->getMemRef(), affineIndices, projectedPermutationMap(read, hwVectorType), read->getPaddingValue()); - return cast<OperationStmt>(cloned->getOperation()); + return cast<OperationInst>(cloned->getOperation()); } /// Creates an instantiated version of `write` for the instance of @@ -495,7 +495,7 @@ instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, /// `hwVectorType` int the covering of th3e super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. -static OperationStmt * +static OperationInst * instantiate(FuncBuilder *b, VectorTransferWriteOp *write, VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, DenseMap<const Value *, Value *> *substitutionsMap) { @@ -508,7 +508,7 @@ instantiate(FuncBuilder *b, VectorTransferWriteOp *write, substitute(write->getVector(), hwVectorType, substitutionsMap), write->getMemRef(), affineIndices, projectedPermutationMap(write, hwVectorType)); - return cast<OperationStmt>(cloned->getOperation()); + return cast<OperationInst>(cloned->getOperation()); } /// Returns `true` if stmt instance is properly cloned and inserted, false @@ -544,7 +544,7 @@ static bool instantiateMaterialization(Statement *stmt, // Create a builder here for unroll-and-jam effects. FuncBuilder b(stmt); - auto *opStmt = cast<OperationStmt>(stmt); + auto *opStmt = cast<OperationInst>(stmt); if (auto write = opStmt->dyn_cast<VectorTransferWriteOp>()) { instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); @@ -620,8 +620,7 @@ static bool emitSlice(MaterializationState *state, } LLVM_DEBUG(dbgs() << "\nMLFunction is now\n"); - LLVM_DEBUG( - cast<OperationStmt>((*slice)[0])->getOperationFunction()->print(dbgs())); + LLVM_DEBUG(cast<OperationInst>((*slice)[0])->getFunction()->print(dbgs())); // slice are topologically sorted, we can just erase them in reverse // order. Reverse iterator does not just work simply with an operator* @@ -652,7 +651,7 @@ static bool emitSlice(MaterializationState *state, /// because we currently disallow vectorization of defs that come from another /// scope. static bool materialize(MLFunction *f, - const SetVector<OperationStmt *> &terminators, + const SetVector<OperationInst *> &terminators, MaterializationState *state) { DenseSet<Statement *> seen; for (auto *term : terminators) { @@ -724,7 +723,7 @@ PassResult MaterializeVectorsPass::runOnMLFunction(MLFunction *f) { // Capture terminators; i.e. vector_transfer_write ops involving a strict // super-vector of subVectorType. auto filter = [subVectorType](const Statement &stmt) { - const auto &opStmt = cast<OperationStmt>(stmt); + const auto &opStmt = cast<OperationInst>(stmt); if (!opStmt.isa<VectorTransferWriteOp>()) { return false; } @@ -732,9 +731,9 @@ PassResult MaterializeVectorsPass::runOnMLFunction(MLFunction *f) { }; auto pat = Op(filter); auto matches = pat.match(f); - SetVector<OperationStmt *> terminators; + SetVector<OperationInst *> terminators; for (auto m : matches) { - terminators.insert(cast<OperationStmt>(m.first)); + terminators.insert(cast<OperationInst>(m.first)); } auto fail = materialize(f, terminators, &state); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index ce2fac72933..0096cd7be2d 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -64,7 +64,7 @@ FunctionPass *mlir::createPipelineDataTransferPass() { // Returns the position of the tag memref operand given a DMA statement. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { +static unsigned getTagMemRefPos(const OperationInst &dmaStmt) { assert(dmaStmt.isa<DmaStartOp>() || dmaStmt.isa<DmaWaitOp>()); if (dmaStmt.isa<DmaStartOp>()) { // Second to last operand. @@ -179,13 +179,13 @@ static bool checkTagMatch(OpPointer<DmaStartOp> startOp, // Identify matching DMA start/finish statements to overlap computation with. static void findMatchingStartFinishStmts( ForStmt *forStmt, - SmallVectorImpl<std::pair<OperationStmt *, OperationStmt *>> + SmallVectorImpl<std::pair<OperationInst *, OperationInst *>> &startWaitPairs) { // Collect outgoing DMA statements - needed to check for dependences below. SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps; for (auto &stmt : *forStmt->getBody()) { - auto *opStmt = dyn_cast<OperationStmt>(&stmt); + auto *opStmt = dyn_cast<OperationInst>(&stmt); if (!opStmt) continue; OpPointer<DmaStartOp> dmaStartOp; @@ -194,9 +194,9 @@ static void findMatchingStartFinishStmts( outgoingDmaOps.push_back(dmaStartOp); } - SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts; + SmallVector<OperationInst *, 4> dmaStartStmts, dmaFinishStmts; for (auto &stmt : *forStmt->getBody()) { - auto *opStmt = dyn_cast<OperationStmt>(&stmt); + auto *opStmt = dyn_cast<OperationInst>(&stmt); if (!opStmt) continue; // Collect DMA finish statements. @@ -260,7 +260,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { return success(); } - SmallVector<std::pair<OperationStmt *, OperationStmt *>, 4> startWaitPairs; + SmallVector<std::pair<OperationInst *, OperationInst *>, 4> startWaitPairs; findMatchingStartFinishStmts(forStmt, startWaitPairs); if (startWaitPairs.empty()) { @@ -293,7 +293,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { // operation could have been used on it if it was dynamically shaped in // order to create the double buffer above) if (oldMemRef->use_empty()) - if (auto *allocStmt = oldMemRef->getDefiningStmt()) + if (auto *allocStmt = oldMemRef->getDefiningInst()) allocStmt->erase(); } @@ -309,7 +309,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { // If the old tag has no more uses, remove its 'dead' alloc if it was // alloc'ed. if (oldTagMemRef->use_empty()) - if (auto *allocStmt = oldTagMemRef->getDefiningStmt()) + if (auto *allocStmt = oldTagMemRef->getDefiningInst()) allocStmt->erase(); } @@ -329,7 +329,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { } else { // If a slice wasn't created, the reachable affine_apply op's from its // operands are the ones that go with it. - SmallVector<OperationStmt *, 4> affineApplyStmts; + SmallVector<OperationInst *, 4> affineApplyStmts; SmallVector<Value *, 4> operands(dmaStartStmt->getOperands()); getReachableAffineApplyOps(operands, affineApplyStmts); for (const auto *stmt : affineApplyStmts) { @@ -352,7 +352,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { shifts[s++] = stmtShiftMap[&stmt]; LLVM_DEBUG( // Tagging statements with shifts for debugging purposes. - if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) { + if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) { FuncBuilder b(opStmt); opStmt->setAttr(b.getIdentifier("shift"), b.getI64IntegerAttr(shifts[s - 1])); diff --git a/mlir/lib/Transforms/SimplifyAffineExpr.cpp b/mlir/lib/Transforms/SimplifyAffineExpr.cpp index 048e26ae115..b0b31e01175 100644 --- a/mlir/lib/Transforms/SimplifyAffineExpr.cpp +++ b/mlir/lib/Transforms/SimplifyAffineExpr.cpp @@ -47,7 +47,7 @@ struct SimplifyAffineStructures : public FunctionPass, PassResult runOnCFGFunction(CFGFunction *f) override { return success(); } void visitIfStmt(IfStmt *ifStmt); - void visitOperationStmt(OperationStmt *opStmt); + void visitOperationInst(OperationInst *opStmt); static char passID; }; @@ -75,7 +75,7 @@ void SimplifyAffineStructures::visitIfStmt(IfStmt *ifStmt) { ifStmt->setIntegerSet(simplifyIntegerSet(set)); } -void SimplifyAffineStructures::visitOperationStmt(OperationStmt *opStmt) { +void SimplifyAffineStructures::visitOperationInst(OperationInst *opStmt) { for (auto attr : opStmt->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) { MutableAffineMap mMap(mapAttr.getValue()); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index a690844f7a6..f493e4b090b 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -39,7 +39,7 @@ public: void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter); - void addToWorklist(Operation *op) { + void addToWorklist(OperationInst *op) { // Check to see if the worklist already contains this op. if (worklistMap.count(op)) return; @@ -48,7 +48,7 @@ public: worklist.push_back(op); } - Operation *popFromWorklist() { + OperationInst *popFromWorklist() { auto *op = worklist.back(); worklist.pop_back(); @@ -60,7 +60,7 @@ public: /// If the specified operation is in the worklist, remove it. If not, this is /// a no-op. - void removeFromWorklist(Operation *op) { + void removeFromWorklist(OperationInst *op) { auto it = worklistMap.find(op); if (it != worklistMap.end()) { assert(worklist[it->second] == op && "malformed worklist data structure"); @@ -76,13 +76,13 @@ private: /// need to be revisited, plus their index in the worklist. This allows us to /// efficiently remove operations from the worklist when they are removed even /// if they aren't the root of a pattern. - std::vector<Operation *> worklist; - DenseMap<Operation *, unsigned> worklistMap; + std::vector<OperationInst *> worklist; + DenseMap<OperationInst *, unsigned> worklistMap; /// As part of canonicalization, we move constants to the top of the entry /// block of the current function and de-duplicate them. This keeps track of /// constants we have done this for. - DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants; + DenseMap<std::pair<Attribute, Type>, OperationInst *> uniquedConstants; }; }; // end anonymous namespace @@ -94,22 +94,22 @@ public: WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context) : PatternRewriter(context), driver(driver) {} - virtual void setInsertionPoint(Operation *op) = 0; + virtual void setInsertionPoint(OperationInst *op) = 0; // If an operation is about to be removed, make sure it is not in our // worklist anymore because we'd get dangling references to it. - void notifyOperationRemoved(Operation *op) override { + void notifyOperationRemoved(OperationInst *op) override { driver.removeFromWorklist(op); } // When the root of a pattern is about to be replaced, it can trigger // simplifications to its users - make sure to add them to the worklist // before the root is changed. - void notifyRootReplaced(Operation *op) override { + void notifyRootReplaced(OperationInst *op) override { for (auto *result : op->getResults()) // TODO: Add a result->getUsers() iterator. for (auto &user : result->getUses()) { - if (auto *op = dyn_cast<Operation>(user.getOwner())) + if (auto *op = dyn_cast<OperationInst>(user.getOwner())) driver.addToWorklist(op); } @@ -168,7 +168,6 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, // canonical version. To ensure safe dominance, move the operation to the // top of the function. entry = op; - auto &entryBB = currentFunction->front(); op->moveBefore(&entryBB, entryBB.begin()); continue; @@ -186,7 +185,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, operandConstants.clear(); for (auto *operand : op->getOperands()) { Attribute operandCst; - if (auto *operandOp = operand->getDefiningOperation()) { + if (auto *operandOp = operand->getDefiningInst()) { if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>()) operandCst = operandConstantOp->getValue(); } @@ -219,7 +218,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, // // TODO: Add a result->getUsers() iterator. for (auto &operand : op->getResult(i)->getUses()) { - if (auto *op = dyn_cast<Operation>(operand.getOwner())) + if (auto *op = dyn_cast<OperationInst>(operand.getOwner())) addToWorklist(op); } @@ -265,15 +264,15 @@ static void processMLFunction(MLFunction *fn, // Implement the hook for creating operations, and make sure that newly // created ops are added to the worklist for processing. - Operation *createOperation(const OperationState &state) override { + OperationInst *createOperation(const OperationState &state) override { auto *result = builder.createOperation(state); driver.addToWorklist(result); return result; } - void setInsertionPoint(Operation *op) override { + void setInsertionPoint(OperationInst *op) override { // Any new operations should be added before this statement. - builder.setInsertionPoint(cast<OperationStmt>(op)); + builder.setInsertionPoint(cast<OperationInst>(op)); } private: @@ -281,7 +280,7 @@ static void processMLFunction(MLFunction *fn, }; GreedyPatternRewriteDriver driver(std::move(patterns)); - fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); }); + fn->walk([&](OperationInst *stmt) { driver.addToWorklist(stmt); }); FuncBuilder mlBuilder(fn); MLFuncRewriter rewriter(driver, mlBuilder); @@ -297,13 +296,13 @@ static void processCFGFunction(CFGFunction *fn, // Implement the hook for creating operations, and make sure that newly // created ops are added to the worklist for processing. - Operation *createOperation(const OperationState &state) override { + OperationInst *createOperation(const OperationState &state) override { auto *result = builder.createOperation(state); driver.addToWorklist(result); return result; } - void setInsertionPoint(Operation *op) override { + void setInsertionPoint(OperationInst *op) override { // Any new operations should be added before this instruction. builder.setInsertionPoint(cast<OperationInst>(op)); } @@ -315,7 +314,7 @@ static void processCFGFunction(CFGFunction *fn, GreedyPatternRewriteDriver driver(std::move(patterns)); for (auto &bb : *fn) for (auto &op : bb) - if (auto *opInst = dyn_cast<OperationStmt>(&op)) + if (auto *opInst = dyn_cast<OperationInst>(&op)) driver.addToWorklist(opInst); FuncBuilder cfgBuilder(fn); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index b92e15d7857..4a2831c0a83 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -156,7 +156,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, auto *loopChunk = b->createFor(srcForStmt->getLoc(), lbOperands, lbMap, ubOperands, ubMap, srcForStmt->getStep()); - OperationStmt::OperandMapTy operandMap; + OperationInst::OperandMapTy operandMap; for (auto it = stmtGroupQueue.begin() + offset, e = stmtGroupQueue.end(); it != e; ++it) { diff --git a/mlir/lib/Transforms/Utils/LoweringUtils.cpp b/mlir/lib/Transforms/Utils/LoweringUtils.cpp index 90f4d0c028d..6fca54a9972 100644 --- a/mlir/lib/Transforms/Utils/LoweringUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoweringUtils.cpp @@ -124,7 +124,7 @@ bool mlir::expandAffineApply(AffineApplyOp *op) { if (!op) return true; - FuncBuilder builder(cast<OperationStmt>(op->getOperation())); + FuncBuilder builder(cast<OperationInst>(op->getOperation())); auto affineMap = op->getAffineMap(); for (auto numberedExpr : llvm::enumerate(affineMap.getResults())) { Value *expanded = expandAffineExpr(&builder, numberedExpr.value(), op); diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 2bc1be1b785..c8317c27f74 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -36,7 +36,7 @@ using namespace mlir; /// Return true if this operation dereferences one or more memref's. // Temporary utility: will be replaced when this is modeled through // side-effects/op traits. TODO(b/117228571) -static bool isMemRefDereferencingOp(const Operation &op) { +static bool isMemRefDereferencingOp(const OperationInst &op) { if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() || op.isa<DmaWaitOp>()) return true; @@ -82,10 +82,10 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, assert(oldMemRef->getType().cast<MemRefType>().getElementType() == newMemRef->getType().cast<MemRefType>().getElementType()); - // Walk all uses of old memref. Statement using the memref gets replaced. + // Walk all uses of old memref. Operation using the memref gets replaced. for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) { StmtOperand &use = *(it++); - auto *opStmt = cast<OperationStmt>(use.getOwner()); + auto *opStmt = cast<OperationInst>(use.getOwner()); // Skip this use if it's not dominated by domStmtFilter. if (domStmtFilter && !dominates(*domStmtFilter, *opStmt)) @@ -124,7 +124,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, // TODO(mlir-team): An operation/SSA value should provide a method to // return the position of an SSA result in its defining // operation. - assert(extraIndex->getDefiningStmt()->getNumResults() == 1 && + assert(extraIndex->getDefiningInst()->getNumResults() == 1 && "single result op's expected to generate these indices"); assert((extraIndex->isValidDim() || extraIndex->isValidSymbol()) && "invalid memory op index"); @@ -186,10 +186,10 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, // operands were drawing results from multiple affine apply ops, this also leads // to a collapse into a single affine apply op. The final results of the // composed AffineApplyOp are returned in output parameter 'results'. -OperationStmt * +OperationInst * mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc, ArrayRef<Value *> operands, - ArrayRef<OperationStmt *> affineApplyOps, + ArrayRef<OperationInst *> affineApplyOps, SmallVectorImpl<Value *> *results) { // Create identity map with same number of dimensions as number of operands. auto map = builder->getMultiDimIdentityMap(operands.size()); @@ -216,7 +216,7 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc, for (unsigned i = 0, e = operands.size(); i < e; ++i) { (*results)[i] = affineApplyOp->getResult(i); } - return cast<OperationStmt>(affineApplyOp->getOperation()); + return cast<OperationInst>(affineApplyOp->getOperation()); } /// Given an operation statement, inserts a new single affine apply operation, @@ -247,19 +247,19 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc, /// all the affine_apply op's supplying operands to this opStmt do not have any /// uses besides this opStmt. Returns the new affine_apply operation statement /// otherwise. -OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) { +OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) { // Collect all operands that are results of affine apply ops. SmallVector<Value *, 4> subOperands; subOperands.reserve(opStmt->getNumOperands()); for (auto *operand : opStmt->getOperands()) { - auto *defStmt = operand->getDefiningStmt(); + auto *defStmt = operand->getDefiningInst(); if (defStmt && defStmt->isa<AffineApplyOp>()) { subOperands.push_back(operand); } } // Gather sequence of AffineApplyOps reachable from 'subOperands'. - SmallVector<OperationStmt *, 4> affineApplyOps; + SmallVector<OperationInst *, 4> affineApplyOps; getReachableAffineApplyOps(subOperands, affineApplyOps); // Skip transforming if there are no affine maps to compose. if (affineApplyOps.empty()) @@ -313,11 +313,11 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) { } void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) { - if (!affineApplyOp->getOperation()->getOperationFunction()->isML()) { + if (!affineApplyOp->getOperation()->getFunction()->isML()) { // TODO: Support forward substitution for CFG style functions. return; } - auto *opStmt = cast<OperationStmt>(affineApplyOp->getOperation()); + auto *opStmt = cast<OperationInst>(affineApplyOp->getOperation()); // Iterate through all uses of all results of 'opStmt', forward substituting // into any uses which are AffineApplyOps. for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e; @@ -326,7 +326,7 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) { for (auto it = result->use_begin(); it != result->use_end();) { StmtOperand &use = *(it++); auto *useStmt = use.getOwner(); - auto *useOpStmt = dyn_cast<OperationStmt>(useStmt); + auto *useOpStmt = dyn_cast<OperationInst>(useStmt); // Skip if use is not AffineApplyOp. if (useOpStmt == nullptr || !useOpStmt->isa<AffineApplyOp>()) continue; @@ -379,7 +379,7 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) { : forStmt->getUpperBoundOperands(); for (const auto *operand : boundOperands) { Attribute operandCst; - if (auto *operandOp = operand->getDefiningOperation()) { + if (auto *operandOp = operand->getDefiningInst()) { if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>()) operandCst = operandConstantOp->getValue(); } @@ -415,7 +415,8 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) { } void mlir::remapFunctionAttrs( - Operation &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) { + OperationInst &op, + const DenseMap<Attribute, FunctionAttr> &remappingTable) { for (auto attr : op.getAttrs()) { // Do the remapping, if we got the same thing back, then it must contain // functions that aren't getting remapped. @@ -451,7 +452,7 @@ void mlir::remapFunctionAttrs( struct MLFnWalker : public StmtWalker<MLFnWalker> { MLFnWalker(const DenseMap<Attribute, FunctionAttr> &remappingTable) : remappingTable(remappingTable) {} - void visitOperationStmt(OperationStmt *opStmt) { + void visitOperationInst(OperationInst *opStmt) { remapFunctionAttrs(*opStmt, remappingTable); } diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index b8145126770..5abd3a3cfcc 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -98,7 +98,7 @@ void VectorizerTestPass::testVectorShapeRatio(MLFunction *f) { // Only filter statements that operate on a strict super-vector and have one // return. This makes testing easier. auto filter = [subVectorType](const Statement &stmt) { - auto *opStmt = dyn_cast<OperationStmt>(&stmt); + auto *opStmt = dyn_cast<OperationInst>(&stmt); if (!opStmt) { return false; } @@ -116,7 +116,7 @@ void VectorizerTestPass::testVectorShapeRatio(MLFunction *f) { auto pat = Op(filter); auto matches = pat.match(f); for (auto m : matches) { - auto *opStmt = cast<OperationStmt>(m.first); + auto *opStmt = cast<OperationInst>(m.first); // This is a unit test that only checks and prints shape ratio. // As a consequence we write only Ops with a single return type for the // purpose of this test. If we need to test more intricate behavior in the @@ -146,7 +146,7 @@ static MLFunctionMatches matchTestSlicingOps(MLFunction *f) { using matcher::Op; // Match all OpStatements with the kTestSlicingOpName name. auto filter = [](const Statement &stmt) { - const auto &opStmt = cast<OperationStmt>(stmt); + const auto &opStmt = cast<OperationInst>(stmt); return opStmt.getName().getStringRef() == kTestSlicingOpName; }; auto pat = Op(filter); @@ -192,7 +192,7 @@ void VectorizerTestPass::testSlicing(MLFunction *f) { } bool customOpWithAffineMapAttribute(const Statement &stmt) { - const auto &opStmt = cast<OperationStmt>(stmt); + const auto &opStmt = cast<OperationInst>(stmt); return opStmt.getName().getStringRef() == VectorizerTestPass::kTestAffineMapOpName; } @@ -205,7 +205,7 @@ void VectorizerTestPass::testComposeMaps(MLFunction *f) { maps.reserve(matches.size()); std::reverse(matches.begin(), matches.end()); for (auto m : matches) { - auto *opStmt = cast<OperationStmt>(m.first); + auto *opStmt = cast<OperationInst>(m.first); auto map = opStmt->getAttr(VectorizerTestPass::kTestAffineMapAttrName) .cast<AffineMapAttr>() .getValue(); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 80d16475e47..0efe727f5b4 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -722,22 +722,22 @@ namespace { struct VectorizationState { /// Adds an entry of pre/post vectorization statements in the state. - void registerReplacement(OperationStmt *key, OperationStmt *value); + void registerReplacement(OperationInst *key, OperationInst *value); /// When the current vectorization pattern is successful, this erases the /// instructions that were marked for erasure in the proper order and resets /// the internal state for the next pattern. void finishVectorizationPattern(); - // In-order tracking of original OperationStmt that have been vectorized. + // In-order tracking of original OperationInst that have been vectorized. // Erase in reverse order. - SmallVector<OperationStmt *, 16> toErase; - // Set of OperationStmt that have been vectorized (the values in the + SmallVector<OperationInst *, 16> toErase; + // Set of OperationInst that have been vectorized (the values in the // vectorizationMap for hashed access). The vectorizedSet is used in // particular to filter the statements that have already been vectorized by // this pattern, when iterating over nested loops in this pattern. - DenseSet<OperationStmt *> vectorizedSet; - // Map of old scalar OperationStmt to new vectorized OperationStmt. - DenseMap<OperationStmt *, OperationStmt *> vectorizationMap; + DenseSet<OperationInst *> vectorizedSet; + // Map of old scalar OperationInst to new vectorized OperationInst. + DenseMap<OperationInst *, OperationInst *> vectorizationMap; // Map of old scalar Value to new vectorized Value. DenseMap<const Value *, Value *> replacementMap; // The strategy drives which loop to vectorize by which amount. @@ -746,17 +746,17 @@ struct VectorizationState { // vectorizeOperations function. They consist of the subset of load operations // that have been vectorized. They can be retrieved from `vectorizationMap` // but it is convenient to keep track of them in a separate data structure. - DenseSet<OperationStmt *> roots; + DenseSet<OperationInst *> roots; // Terminator statements for the worklist in the vectorizeOperations function. // They consist of the subset of store operations that have been vectorized. // They can be retrieved from `vectorizationMap` but it is convenient to keep // track of them in a separate data structure. Since they do not necessarily // belong to use-def chains starting from loads (e.g storing a constant), we // need to handle them in a post-pass. - DenseSet<OperationStmt *> terminators; + DenseSet<OperationInst *> terminators; // Checks that the type of `stmt` is StoreOp and adds it to the terminators // set. - void registerTerminator(OperationStmt *stmt); + void registerTerminator(OperationInst *stmt); private: void registerReplacement(const Value *key, Value *value); @@ -764,8 +764,8 @@ private: } // end namespace -void VectorizationState::registerReplacement(OperationStmt *key, - OperationStmt *value) { +void VectorizationState::registerReplacement(OperationInst *key, + OperationInst *value) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op: "); LLVM_DEBUG(key->print(dbgs())); LLVM_DEBUG(dbgs() << " into "); @@ -784,7 +784,7 @@ void VectorizationState::registerReplacement(OperationStmt *key, } } -void VectorizationState::registerTerminator(OperationStmt *stmt) { +void VectorizationState::registerTerminator(OperationInst *stmt) { assert(stmt->isa<StoreOp>() && "terminator must be a StoreOp"); assert(terminators.count(stmt) == 0 && "terminator was already inserted previously"); @@ -832,7 +832,7 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp, auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType); // Materialize a MemRef with 1 vector. - auto *opStmt = cast<OperationStmt>(memoryOp->getOperation()); + auto *opStmt = cast<OperationInst>(memoryOp->getOperation()); // For now, vector_transfers must be aligned, operate only on indices with an // identity subset of AffineMap and do not change layout. // TODO(ntv): increase the expressiveness power of vector_transfer operations @@ -847,7 +847,7 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp, opStmt->getLoc(), vectorType, memoryOp->getMemRef(), map(makePtrDynCaster<Value>(), memoryOp->getIndices()), permutationMap); state->registerReplacement(opStmt, - cast<OperationStmt>(transfer->getOperation())); + cast<OperationInst>(transfer->getOperation())); } else { state->registerTerminator(opStmt); } @@ -866,7 +866,7 @@ static bool vectorizeForStmt(ForStmt *loop, int64_t step, if (!matcher::isLoadOrStore(stmt)) { return false; } - auto *opStmt = cast<OperationStmt>(&stmt); + auto *opStmt = cast<OperationInst>(&stmt); return state->vectorizationMap.count(opStmt) == 0 && state->vectorizedSet.count(opStmt) == 0 && state->roots.count(opStmt) == 0 && @@ -875,7 +875,7 @@ static bool vectorizeForStmt(ForStmt *loop, int64_t step, auto loadAndStores = matcher::Op(notVectorizedThisPattern); auto matches = loadAndStores.match(loop); for (auto ls : matches) { - auto *opStmt = cast<OperationStmt>(ls.first); + auto *opStmt = cast<OperationInst>(ls.first); auto load = opStmt->dyn_cast<LoadOp>(); auto store = opStmt->dyn_cast<StoreOp>(); LLVM_DEBUG(opStmt->print(dbgs())); @@ -974,14 +974,14 @@ static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant, Location loc = stmt->getLoc(); auto vectorType = type.cast<VectorType>(); auto attr = SplatElementsAttr::get(vectorType, constant.getValue()); - auto *constantOpStmt = cast<OperationStmt>(constant.getOperation()); + auto *constantOpStmt = cast<OperationInst>(constant.getOperation()); OperationState state( b.getContext(), loc, constantOpStmt->getName().getStringRef(), {}, {vectorType}, {make_pair(Identifier::get("value", b.getContext()), attr)}); - auto *splat = cast<OperationStmt>(b.createOperation(state)); + auto *splat = cast<OperationInst>(b.createOperation(state)); return splat->getResult(0); } @@ -994,7 +994,7 @@ static Type getVectorType(Value *v, const VectorizationState &state) { if (!VectorType::isValidElementType(v->getType())) { return Type(); } - auto *definingOpStmt = cast<OperationStmt>(v->getDefiningStmt()); + auto *definingOpStmt = cast<OperationInst>(v->getDefiningInst()); if (state.vectorizedSet.count(definingOpStmt) > 0) { return v->getType().cast<VectorType>(); } @@ -1026,7 +1026,7 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt, VectorizationState *state) { LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: "); LLVM_DEBUG(operand->print(dbgs())); - auto *definingStatement = cast<OperationStmt>(operand->getDefiningStmt()); + auto *definingStatement = cast<OperationInst>(operand->getDefiningInst()); // 1. If this value has already been vectorized this round, we are done. if (state->vectorizedSet.count(definingStatement) > 0) { LLVM_DEBUG(dbgs() << " -> already vector operand"); @@ -1049,7 +1049,7 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt, return nullptr; } // 3. vectorize constant. - if (auto constant = operand->getDefiningStmt()->dyn_cast<ConstantOp>()) { + if (auto constant = operand->getDefiningInst()->dyn_cast<ConstantOp>()) { return vectorizeConstant(stmt, *constant, getVectorType(operand, *state).cast<VectorType>()); } @@ -1059,17 +1059,17 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt, return nullptr; }; -/// Encodes OperationStmt-specific behavior for vectorization. In general we +/// Encodes OperationInst-specific behavior for vectorization. In general we /// assume that all operands of an op must be vectorized but this is not always /// true. In the future, it would be nice to have a trait that describes how a /// particular operation vectorizes. For now we implement the case distinction /// here. -/// Returns a vectorized form of stmt or nullptr if vectorization fails. +/// Returns a vectorized form of an operation or nullptr if vectorization fails. /// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized. /// Maybe some Ops are not vectorizable or require some tricky logic, we cannot /// do one-off logic here; ideally it would be TableGen'd. -static OperationStmt *vectorizeOneOperationStmt(FuncBuilder *b, - OperationStmt *opStmt, +static OperationInst *vectorizeOneOperationInst(FuncBuilder *b, + OperationInst *opStmt, VectorizationState *state) { // Sanity checks. assert(!opStmt->isa<LoadOp>() && @@ -1091,7 +1091,7 @@ static OperationStmt *vectorizeOneOperationStmt(FuncBuilder *b, LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = b.create<VectorTransferWriteOp>( opStmt->getLoc(), vectorValue, memRef, indices, permutationMap); - auto *res = cast<OperationStmt>(transfer->getOperation()); + auto *res = cast<OperationInst>(transfer->getOperation()); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res); // "Terminators" (i.e. StoreOps) are erased on the spot. opStmt->erase(); @@ -1114,8 +1114,8 @@ static OperationStmt *vectorizeOneOperationStmt(FuncBuilder *b, // Create a clone of the op with the proper operands and return types. // TODO(ntv): The following assumes there is always an op with a fixed // name that works both in scalar mode and vector mode. - // TODO(ntv): Is it worth considering an OperationStmt.clone operation - // which changes the type so we can promote an OperationStmt with less + // TODO(ntv): Is it worth considering an OperationInst.clone operation + // which changes the type so we can promote an OperationInst with less // boilerplate? OperationState newOp(b->getContext(), opStmt->getLoc(), opStmt->getName().getStringRef(), operands, types, @@ -1123,22 +1123,22 @@ static OperationStmt *vectorizeOneOperationStmt(FuncBuilder *b, return b->createOperation(newOp); } -/// Iterates over the OperationStmt in the loop and rewrites them using their +/// Iterates over the OperationInst in the loop and rewrites them using their /// vectorized counterpart by: -/// 1. iteratively building a worklist of uses of the OperationStmt vectorized +/// 1. iteratively building a worklist of uses of the OperationInst vectorized /// so far by this pattern; -/// 2. for each OperationStmt in the worklist, create the vector form of this +/// 2. for each OperationInst in the worklist, create the vector form of this /// operation and replace all its uses by the vectorized form. For this step, /// the worklist must be traversed in order; /// 3. verify that all operands of the newly vectorized operation have been /// vectorized by this pattern. static bool vectorizeOperations(VectorizationState *state) { // 1. create initial worklist with the uses of the roots. - SetVector<OperationStmt *> worklist; - auto insertUsesOf = [&worklist, state](Operation *vectorized) { - for (auto *r : cast<OperationStmt>(vectorized)->getResults()) + SetVector<OperationInst *> worklist; + auto insertUsesOf = [&worklist, state](OperationInst *vectorized) { + for (auto *r : vectorized->getResults()) for (auto &u : r->getUses()) { - auto *stmt = cast<OperationStmt>(u.getOwner()); + auto *stmt = cast<OperationInst>(u.getOwner()); // Don't propagate to terminals, a separate pass is needed for those. // TODO(ntv)[b/119759136]: use isa<> once Op is implemented. if (state->terminators.count(stmt) > 0) { @@ -1160,7 +1160,7 @@ static bool vectorizeOperations(VectorizationState *state) { // 2. Create vectorized form of the statement. // Insert it just before stmt, on success register stmt as replaced. FuncBuilder b(stmt); - auto *vectorizedStmt = vectorizeOneOperationStmt(&b, stmt, state); + auto *vectorizedStmt = vectorizeOneOperationInst(&b, stmt, state); if (!vectorizedStmt) { return true; } @@ -1169,11 +1169,11 @@ static bool vectorizeOperations(VectorizationState *state) { // Note that we cannot just call replaceAllUsesWith because it may // result in ops with mixed types, for ops whose operands have not all // yet been vectorized. This would be invalid IR. - state->registerReplacement(cast<OperationStmt>(stmt), vectorizedStmt); + state->registerReplacement(stmt, vectorizedStmt); // 4. Augment the worklist with uses of the statement we just vectorized. // This preserves the proper order in the worklist. - apply(insertUsesOf, ArrayRef<Operation *>{stmt}); + apply(insertUsesOf, ArrayRef<OperationInst *>{stmt}); } return false; } @@ -1223,12 +1223,12 @@ static bool vectorizeRootMatches(MLFunctionMatches matches, // Form the root operationsthat have been set in the replacementMap. // For now, these roots are the loads for which vector_transfer_read // operations have been inserted. - auto getDefiningOperation = [](const Value *val) { - return const_cast<Value *>(val)->getDefiningOperation(); + auto getDefiningInst = [](const Value *val) { + return const_cast<Value *>(val)->getDefiningInst(); }; using ReferenceTy = decltype(*(state.replacementMap.begin())); auto getKey = [](ReferenceTy it) { return it.first; }; - auto roots = map(getDefiningOperation, map(getKey, state.replacementMap)); + auto roots = map(getDefiningInst, map(getKey, state.replacementMap)); // Vectorize the root operations and everything reached by use-def chains // except the terminators (store statements) that need to be post-processed @@ -1240,12 +1240,12 @@ static bool vectorizeRootMatches(MLFunctionMatches matches, } // Finally, vectorize the terminators. If anything fails to vectorize, skip. - auto vectorizeOrFail = [&fail, &state](OperationStmt *stmt) { + auto vectorizeOrFail = [&fail, &state](OperationInst *stmt) { if (fail) { return; } FuncBuilder b(stmt); - auto *res = vectorizeOneOperationStmt(&b, stmt, &state); + auto *res = vectorizeOneOperationInst(&b, stmt, &state); if (res == nullptr) { fail = true; } |

