diff options
| author | River Riddle <riverriddle@google.com> | 2019-02-03 10:03:46 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 16:09:36 -0700 |
| commit | 870d7783503962a7043b2654ab82a9d4f4f1a961 (patch) | |
| tree | eec8283dd8f17286c13360aa07cb9e1412e59a5e | |
| parent | de2d0dfbcab7e79bb6238e0a105f2747783246eb (diff) | |
| download | bcm5719-llvm-870d7783503962a7043b2654ab82a9d4f4f1a961.tar.gz bcm5719-llvm-870d7783503962a7043b2654ab82a9d4f4f1a961.zip | |
Begin the process of fully removing OperationInst. This patch cleans up references to OperationInst in the /include, /AffineOps, and lib/Analysis.
PiperOrigin-RevId: 232199262
44 files changed, 356 insertions, 446 deletions
diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 955c2761ac6..12e4589405f 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -182,15 +182,15 @@ public: /// Walk the operation instructions in the 'for' instruction in preorder, /// calling the callback for each operation. - void walkOps(std::function<void(OperationInst *)> callback); + void walkOps(std::function<void(Instruction *)> callback); /// Walk the operation instructions in the 'for' instruction in postorder, /// calling the callback for each operation. - void walkOpsPostOrder(std::function<void(OperationInst *)> callback); + void walkOpsPostOrder(std::function<void(Instruction *)> callback); private: friend class Instruction; - explicit AffineForOp(const OperationInst *state) : Op(state) {} + explicit AffineForOp(const Instruction *state) : Op(state) {} }; /// Returns if the provided value is the induction variable of a AffineForOp. @@ -224,13 +224,11 @@ public: using operand_range = AffineForOp::operand_range; operand_iterator operand_begin() const { - return const_cast<OperationInst *>(inst->getInstruction()) - ->operand_begin() + + return const_cast<Instruction *>(inst->getInstruction())->operand_begin() + opStart; } operand_iterator operand_end() const { - return const_cast<OperationInst *>(inst->getInstruction()) - ->operand_begin() + + return const_cast<Instruction *>(inst->getInstruction())->operand_begin() + opEnd; } operand_range getOperands() const { return {operand_begin(), operand_end()}; } @@ -300,7 +298,7 @@ public: private: friend class Instruction; - explicit AffineIfOp(const OperationInst *state) : Op(state) {} + explicit AffineIfOp(const Instruction *state) : Op(state) {} }; } // end namespace mlir diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index 5eaf4e11ed4..ca420bab7e1 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -41,7 +41,6 @@ class Instruction; class IntegerSet; class Location; class MLIRContext; -using OperationInst = Instruction; template <typename OpType> class OpPointer; class Value; @@ -74,11 +73,11 @@ void fullyComposeAffineMapAndOperands(AffineMap *map, llvm::SmallVectorImpl<Value *> *operands); /// Returns in `affineApplyOps`, the sequence of those AffineApplyOp -/// OperationInsts that are reachable via a search starting from `operands` and +/// Instructions that are reachable via a search starting from `operands` and /// ending at those operands that are not the result of an AffineApplyOp. void getReachableAffineApplyOps( llvm::ArrayRef<Value *> operands, - llvm::SmallVectorImpl<OperationInst *> &affineApplyOps); + llvm::SmallVectorImpl<Instruction *> &affineApplyOps); /// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false /// if 'expr' could not be flattened (i.e., semi-affine is not yet handled). @@ -119,13 +118,13 @@ bool getIndexSet(llvm::MutableArrayRef<OpPointer<AffineForOp>> forOps, /// Encapsulates a memref load or store access information. struct MemRefAccess { const Value *memref; - const OperationInst *opInst; + const Instruction *opInst; llvm::SmallVector<Value *, 4> indices; /// Constructs a MemRefAccess from a load or store operation instruction. // TODO(b/119949820): add accessors to standard op's load, store, DMA op's to // return MemRefAccess, i.e., loadOp->getAccess(), dmaOp->getRead/WriteAccess. - explicit MemRefAccess(OperationInst *opInst); + explicit MemRefAccess(Instruction *opInst); /// Populates 'accessMap' with composition of AffineApplyOps reachable from // 'indices'. diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h index ac07fa349f8..e15fffc7d51 100644 --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -34,7 +34,6 @@ class AffineMap; template <typename T> class ConstOpPointer; class Instruction; class MemRefType; -using OperationInst = Instruction; class Value; /// Returns the trip count of the loop as an affine expression if the latter is diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index 0e41058f777..5c040ecbe08 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -97,7 +97,7 @@ private: using FilterFunctionType = std::function<bool(const Instruction &)>; static bool defaultFilterFunction(const Instruction &) { return true; }; struct NestedPattern { - NestedPattern(Instruction::Kind k, ArrayRef<NestedPattern> nested, + NestedPattern(ArrayRef<NestedPattern> nested, FilterFunctionType filter = defaultFilterFunction); NestedPattern(const NestedPattern &) = default; NestedPattern &operator=(const NestedPattern &) = default; @@ -127,7 +127,7 @@ private: struct State : public InstWalker<State> { State(NestedPattern &pattern, SmallVectorImpl<NestedMatch> *matches) : pattern(pattern), matches(matches) {} - void visitOperationInst(OperationInst *opInst) { + void visitOperationInst(Instruction *opInst) { pattern.matchOne(opInst, matches); } @@ -143,9 +143,6 @@ private: /// result. void matchOne(Instruction *inst, SmallVectorImpl<NestedMatch> *matches); - /// Instruction kind matched by this pattern. - Instruction::Kind kind; - /// Nested patterns to be matched. ArrayRef<NestedPattern> nestedPatterns; diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h index c3cafbc8ae4..3adefdf8bb2 100644 --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -168,10 +168,10 @@ void getBackwardSlice( /// /____\ /// /// We want to iteratively apply `getSlice` to construct the whole -/// list of OperationInst that are reachable by (use|def)+ from inst. +/// list of Instruction that are reachable by (use|def)+ from inst. /// We want the resulting slice in topological order. /// Ideally we would like the ordering to be maintained in-place to avoid -/// copying OperationInst at each step. Keeping this ordering by construction +/// copying Instruction at each step. Keeping this ordering by construction /// seems very unclear, so we list invariants in the hope of seeing whether /// useful properties pop up. /// @@ -207,7 +207,7 @@ llvm::SetVector<Instruction *> getSlice( [](Instruction *) { return true; }); /// Multi-root DAG topological sort. -/// Performs a topological sort of the OperationInst in the `toSort` SetVector. +/// Performs a topological sort of the Instruction in the `toSort` SetVector. /// Returns a topologically sorted SetVector. llvm::SetVector<Instruction *> topologicalSort(const llvm::SetVector<Instruction *> &toSort); diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index 78baba6f7cf..65af6d7b1f2 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -38,7 +38,6 @@ template <typename T> class ConstOpPointer; class FlatAffineConstraints; class Instruction; class MemRefAccess; -using OperationInst = Instruction; template <typename T> class OpPointer; class Instruction; class Value; @@ -143,7 +142,7 @@ private: /// {memref = %A, write = false, {%i <= m0 <= %i + 7} } /// The last field is a 2-d FlatAffineConstraints symbolic in %i. /// -bool getMemRefRegion(OperationInst *opInst, unsigned loopDepth, +bool getMemRefRegion(Instruction *opInst, unsigned loopDepth, MemRefRegion *region); /// Returns the size of memref data in bytes if it's statically shaped, None @@ -196,8 +195,8 @@ bool getBackwardComputationSliceState(const MemRefAccess &srcAccess, // storage and redundant computation in several cases. // TODO(andydavis) Support computation slices with common surrounding loops. OpPointer<AffineForOp> -insertBackwardComputationSlice(OperationInst *srcOpInst, - OperationInst *dstOpInst, unsigned dstLoopDepth, +insertBackwardComputationSlice(Instruction *srcOpInst, Instruction *dstOpInst, + unsigned dstLoopDepth, ComputationSliceState *sliceState); Optional<int64_t> getMemoryFootprintBytes(ConstOpPointer<AffineForOp> forOp, diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h index 7a5f05a690a..4982481bf6c 100644 --- a/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/mlir/include/mlir/Analysis/VectorAnalysis.h @@ -31,7 +31,6 @@ class FuncBuilder; class Instruction; class Location; class MemRefType; -using OperationInst = Instruction; class Value; class VectorType; @@ -123,7 +122,7 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType); /// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast. /// AffineMap makePermutationMap( - OperationInst *opInst, + Instruction *opInst, const llvm::DenseMap<Instruction *, unsigned> &loopToVectorDim); namespace matcher { @@ -136,8 +135,7 @@ namespace matcher { /// TODO(ntv): this could all be much simpler if we added a bit that a vector /// type to mark that a vector is a strict super-vector but it still does not /// warrant adding even 1 extra bit in the IR for now. -bool operatesOnSuperVectors(const OperationInst &inst, - VectorType subVectorType); +bool operatesOnSuperVectors(const Instruction &inst, VectorType subVectorType); } // end namespace matcher } // end namespace mlir diff --git a/mlir/include/mlir/Dialect/Traits.h b/mlir/include/mlir/Dialect/Traits.h index 55d5ebe896a..3cbaf351905 100644 --- a/mlir/include/mlir/Dialect/Traits.h +++ b/mlir/include/mlir/Dialect/Traits.h @@ -32,7 +32,7 @@ namespace OpTrait { // corresponding trait classes. This avoids them being template // instantiated/duplicated. namespace impl { -bool verifyCompatibleOperandBroadcast(const OperationInst *op); +bool verifyCompatibleOperandBroadcast(const Instruction *op); } // namespace impl namespace util { @@ -54,7 +54,7 @@ template <typename ConcreteType> class BroadcastableTwoOperandsOneResult : public TraitBase<ConcreteType, BroadcastableTwoOperandsOneResult> { public: - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifyCompatibleOperandBroadcast(op); } }; diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index e2e84b03494..d0982630a5a 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -108,7 +108,7 @@ public: } /// Returns the function that this block is part of, even if the block is - /// nested under an OperationInst or ForInst. + /// nested under an operation region. Function *getFunction(); const Function *getFunction() const { return const_cast<Block *>(this)->getFunction(); @@ -233,9 +233,9 @@ public: /// Get the terminator instruction of this block, or null if the block is /// malformed. - OperationInst *getTerminator(); + Instruction *getTerminator(); - const OperationInst *getTerminator() const { + const Instruction *getTerminator() const { return const_cast<Block *>(this)->getTerminator(); } @@ -363,7 +363,7 @@ private: namespace mlir { /// This class contains a list of basic blocks and has a notion of the object it -/// is part of - a Function or OperationInst or ForInst. +/// is part of - a Function or an operation region. class BlockList { public: explicit BlockList(Function *container); @@ -475,7 +475,7 @@ public: } private: - using BBUseIterator = ValueUseIterator<BlockOperand, OperationInst>; + using BBUseIterator = ValueUseIterator<BlockOperand, Instruction>; BBUseIterator bbUseIterator; }; diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index d59f0488e49..1d9421b909f 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -243,7 +243,7 @@ public: Block *getBlock() const { return block; } /// Creates an operation given the fields represented as an OperationState. - OperationInst *createOperation(const OperationState &state); + Instruction *createOperation(const OperationState &state); /// Create operation of specific op type at the current insertion point. template <typename OpTy, typename... Args> diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h index 0230b1ab5ec..8be2dbe7b79 100644 --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -82,7 +82,7 @@ public: private: friend class Instruction; - explicit AffineApplyOp(const OperationInst *state) : Op(state) {} + explicit AffineApplyOp(const Instruction *state) : Op(state) {} }; /// The "br" operation represents a branch instruction in a CFG function. @@ -119,7 +119,7 @@ public: private: friend class Instruction; - explicit BranchOp(const OperationInst *state) : Op(state) {} + explicit BranchOp(const Instruction *state) : Op(state) {} }; /// The "cond_br" operation represents a conditional branch instruction in a @@ -258,7 +258,7 @@ private: } friend class Instruction; - explicit CondBranchOp(const OperationInst *state) : Op(state) {} + explicit CondBranchOp(const Instruction *state) : Op(state) {} }; /// The "constant" operation requires a single attribute named "value". @@ -287,7 +287,7 @@ public: protected: friend class Instruction; - explicit ConstantOp(const OperationInst *state) : Op(state) {} + explicit ConstantOp(const Instruction *state) : Op(state) {} }; /// This is a refinement of the "constant" op for the case where it is @@ -305,11 +305,11 @@ public: return getAttrOfType<FloatAttr>("value").getValue(); } - static bool isClassFor(const OperationInst *op); + static bool isClassFor(const Instruction *op); private: friend class Instruction; - explicit ConstantFloatOp(const OperationInst *state) : ConstantOp(state) {} + explicit ConstantFloatOp(const Instruction *state) : ConstantOp(state) {} }; /// This is a refinement of the "constant" op for the case where it is @@ -332,11 +332,11 @@ public: return getAttrOfType<IntegerAttr>("value").getInt(); } - static bool isClassFor(const OperationInst *op); + static bool isClassFor(const Instruction *op); private: friend class Instruction; - explicit ConstantIntOp(const OperationInst *state) : ConstantOp(state) {} + explicit ConstantIntOp(const Instruction *state) : ConstantOp(state) {} }; /// This is a refinement of the "constant" op for the case where it is @@ -353,11 +353,11 @@ public: return getAttrOfType<IntegerAttr>("value").getInt(); } - static bool isClassFor(const OperationInst *op); + static bool isClassFor(const Instruction *op); private: friend class Instruction; - explicit ConstantIndexOp(const OperationInst *state) : ConstantOp(state) {} + explicit ConstantIndexOp(const Instruction *state) : ConstantOp(state) {} }; /// The "return" operation represents a return instruction within a function. @@ -384,12 +384,12 @@ public: private: friend class Instruction; - explicit ReturnOp(const OperationInst *state) : Op(state) {} + explicit ReturnOp(const Instruction *state) : Op(state) {} }; /// Prints dimension and symbol list. -void printDimAndSymbolList(OperationInst::const_operand_iterator begin, - OperationInst::const_operand_iterator end, +void printDimAndSymbolList(Instruction::const_operand_iterator begin, + Instruction::const_operand_iterator end, unsigned numDims, OpAsmPrinter *p); /// Parses dimension and symbol list and returns true if parsing failed. diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index e1ef802fd8e..ba8c586725a 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -30,7 +30,7 @@ class IntegerSet; class Type; using DialectConstantFoldHook = std::function<bool( - const OperationInst *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>; + const Instruction *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>; using DialectTypeParserHook = std::function<Type(StringRef, Location, MLIRContext *)>; using DialectTypePrinterHook = std::function<void(Type, raw_ostream &)>; @@ -56,7 +56,7 @@ public: /// and fills in the `results` vector. If not, this returns true and /// `results` is unspecified. DialectConstantFoldHook constantFoldHook = - [](const OperationInst *op, ArrayRef<Attribute> operands, + [](const Instruction *op, ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results) { return true; }; /// Registered parsing/printing hooks for types registered to the dialect. diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 9d55e64ec76..3876d750a28 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -118,12 +118,12 @@ public: /// Walk the instructions in the function in preorder, calling the callback /// for each instruction or operation. void walkInsts(std::function<void(Instruction *)> callback); - void walkOps(std::function<void(OperationInst *)> callback); + void walkOps(std::function<void(Instruction *)> callback); /// Walk the instructions in the function in postorder, calling the callback /// for each instruction or operation. void walkInstsPostOrder(std::function<void(Instruction *)> callback); - void walkOpsPostOrder(std::function<void(OperationInst *)> callback); + void walkOpsPostOrder(std::function<void(Instruction *)> callback); //===--------------------------------------------------------------------===// // Arguments diff --git a/mlir/include/mlir/IR/InstVisitor.h b/mlir/include/mlir/IR/InstVisitor.h index bb784b59b47..7b74c69ceef 100644 --- a/mlir/include/mlir/IR/InstVisitor.h +++ b/mlir/include/mlir/IR/InstVisitor.h @@ -44,8 +44,8 @@ // lc.walk(function); // numLoops = lc.numLoops; // -// There are 'visit' methods for OperationInst, ForInst, and -// Function, which recursively process all contained instructions. +// There are 'visit' methods for Instruction and Function, which recursively +// process all contained instructions. // // Note that if you don't implement visitXXX for some instruction type, // the visitXXX method for Instruction superclass will be invoked. @@ -55,9 +55,7 @@ // an implementation of every visit<#Instruction>(InstType *). // // Note that these classes are specifically designed as a template to avoid -// virtual function call overhead. Defining and using a InstVisitor is just -// as efficient as having your own switch instruction over the instruction -// opcode. +// virtual function call overhead. // //===----------------------------------------------------------------------===// @@ -81,12 +79,7 @@ public: RetTy visit(Instruction *s) { static_assert(std::is_base_of<InstVisitor, SubClass>::value, "Must pass the derived type to this template!"); - - switch (s->getKind()) { - case Instruction::Kind::OperationInst: - return static_cast<SubClass *>(this)->visitOperationInst( - cast<OperationInst>(s)); - } + return static_cast<SubClass *>(this)->visitOperationInst(s); } //===--------------------------------------------------------------------===// @@ -99,7 +92,7 @@ public: // When visiting a for inst, if inst, or an operation inst directly, these // methods get called to indicate when transitioning into a new unit. - void visitOperationInst(OperationInst *opInst) {} + void visitOperationInst(Instruction *opInst) {} }; /// Base class for instruction walkers. A walker can traverse depth first in @@ -134,14 +127,14 @@ public: static_cast<SubClass *>(this)->walkPostOrder(it->begin(), it->end()); } - void walkOpInst(OperationInst *opInst) { + void walkOpInst(Instruction *opInst) { static_cast<SubClass *>(this)->visitOperationInst(opInst); for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) static_cast<SubClass *>(this)->walk(block.begin(), block.end()); } - void walkOpInstPostOrder(OperationInst *opInst) { + void walkOpInstPostOrder(Instruction *opInst) { for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) static_cast<SubClass *>(this)->walkPostOrder(block.begin(), @@ -155,11 +148,7 @@ public: "Must pass the derived type to this template!"); static_cast<SubClass *>(this)->visitInstruction(s); - - switch (s->getKind()) { - case Instruction::Kind::OperationInst: - return static_cast<SubClass *>(this)->walkOpInst(cast<OperationInst>(s)); - } + return static_cast<SubClass *>(this)->walkOpInst(s); } // Function to walk a instruction in post order DFS. @@ -167,12 +156,7 @@ public: static_assert(std::is_base_of<InstWalker, SubClass>::value, "Must pass the derived type to this template!"); static_cast<SubClass *>(this)->visitInstruction(s); - - switch (s->getKind()) { - case Instruction::Kind::OperationInst: - return static_cast<SubClass *>(this)->walkOpInstPostOrder( - cast<OperationInst>(s)); - } + return static_cast<SubClass *>(this)->walkOpInstPostOrder(s); } //===--------------------------------------------------------------------===// @@ -186,7 +170,7 @@ public: // called. These are typically O(1) complexity and shouldn't be recursively // processing their descendants in some way. When using RetTy, all of these // need to be overridden. - void visitOperationInst(OperationInst *opInst) {} + void visitOperationInst(Instruction *opInst) {} void visitInstruction(Instruction *inst) {} }; diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index 66396553e5a..6f35ac24180 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -65,7 +65,7 @@ public: // Diagnostic handler registration and use. MLIR supports the ability for the // IR to carry arbitrary metadata about operation location information. If an // problem is detected by the compiler, it can invoke the emitError / - // emitWarning / emitNote method on an OperationInst and have it get reported + // emitWarning / emitNote method on an Instruction and have it get reported // through this interface. // // Tools using MLIR are encouraged to register error handlers and define a @@ -85,7 +85,7 @@ public: /// Emit a diagnostic using the registered issue handle if present, or with /// the default behavior if not. The MLIR compiler should not generally - /// interact with this, it should use methods on OperationInst instead. + /// interact with this, it should use methods on Instruction instead. void emitDiagnostic(Location location, const Twine &message, DiagnosticKind kind) const; diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index 3fd13a4b78f..d162a6aff46 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -68,7 +68,7 @@ struct constant_int_op_binder { /// Creates a matcher instance that binds the value to bv if match succeeds. constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {} - bool match(OperationInst *op) { + bool match(Instruction *op) { if (auto constOp = op->dyn_cast<ConstantOp>()) { auto type = constOp->getResult()->getType(); auto attr = constOp->getAttr("value"); @@ -90,7 +90,7 @@ struct constant_int_op_binder { // The matcher that matches a given target constant scalar / vector splat / // tensor splat integer value. template <int64_t TargetValue> struct constant_int_value_matcher { - bool match(OperationInst *op) { + bool match(Instruction *op) { APInt value; return constant_int_op_binder(&value).match(op) && TargetValue == value; @@ -99,7 +99,7 @@ template <int64_t TargetValue> struct constant_int_value_matcher { /// The matcher that matches a certain kind of op. template <typename OpClass> struct op_matcher { - bool match(OperationInst *op) { return op->isa<OpClass>(); } + bool match(Instruction *op) { return op->isa<OpClass>(); } }; } // end namespace detail diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 0de5559a2a7..eaf695fb799 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -54,12 +54,12 @@ template <typename OpType> struct IsSingleResult { OpType *, OpTrait::OneResult<typename OpType::ConcreteOpType> *>::value; }; -/// This pointer represents a notional "OperationInst*" but where the actual +/// This pointer represents a notional "Instruction*" but where the actual /// storage of the pointer is maintained in the templated "OpType" class. template <typename OpType> class OpPointer { public: - explicit OpPointer() : value(OperationInst::getNull<OpType>().value) {} + explicit OpPointer() : value(Instruction::getNull<OpType>().value) {} explicit OpPointer(OpType value) : value(value) {} OpType &operator*() { return value; } @@ -74,7 +74,7 @@ public: bool operator!=(OpPointer rhs) const { return !(*this == rhs); } /// OpPointer can be implicitly converted to OpType*. - /// Return `nullptr` if there is no associated OperationInst*. + /// Return `nullptr` if there is no associated Instruction*. operator OpType *() { if (!value.getInstruction()) return nullptr; @@ -97,12 +97,12 @@ private: friend class ConstOpPointer<OpType>; }; -/// This pointer represents a notional "const OperationInst*" but where the +/// This pointer represents a notional "const Instruction*" but where the /// actual storage of the pointer is maintained in the templated "OpType" class. template <typename OpType> class ConstOpPointer { public: - explicit ConstOpPointer() : value(OperationInst::getNull<OpType>().value) {} + explicit ConstOpPointer() : value(Instruction::getNull<OpType>().value) {} explicit ConstOpPointer(OpType value) : value(value) {} ConstOpPointer(OpPointer<OpType> pointer) : value(pointer.value) {} @@ -119,7 +119,7 @@ public: bool operator!=(ConstOpPointer rhs) const { return !(*this == rhs); } /// ConstOpPointer can always be implicitly converted to const OpType*. - /// Return `nullptr` if there is no associated OperationInst*. + /// Return `nullptr` if there is no associated Instruction*. operator const OpType *() const { if (!value.getInstruction()) return nullptr; @@ -151,8 +151,8 @@ private: class OpState { public: /// Return the operation that this refers to. - const OperationInst *getInstruction() const { return state; } - OperationInst *getInstruction() { return state; } + const Instruction *getInstruction() const { return state; } + Instruction *getInstruction() { return state; } /// The source location the operation was defined or derived from. Location getLoc() const { return state->getLoc(); } @@ -221,11 +221,11 @@ protected: /// Mutability management is handled by the OpWrapper/OpConstWrapper classes, /// so we can cast it away here. - explicit OpState(const OperationInst *state) - : state(const_cast<OperationInst *>(state)) {} + explicit OpState(const Instruction *state) + : state(const_cast<Instruction *>(state)) {} private: - OperationInst *state; + Instruction *state; }; /// This template defines the constantFoldHook and foldHook as used by @@ -238,7 +238,7 @@ class FoldingHook { public: /// This is an implementation detail of the constant folder hook for /// AbstractOperation. - static bool constantFoldHook(const OperationInst *op, + static bool constantFoldHook(const Instruction *op, ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results) { return op->cast<ConcreteType>()->constantFold(operands, results, @@ -261,7 +261,7 @@ public: } /// This is an implementation detail of the folder hook for AbstractOperation. - static bool foldHook(OperationInst *op, SmallVectorImpl<Value *> &results) { + static bool foldHook(Instruction *op, SmallVectorImpl<Value *> &results) { return op->cast<ConcreteType>()->fold(results); } @@ -300,7 +300,7 @@ class FoldingHook<ConcreteType, isSingleResult, public: /// This is an implementation detail of the constant folder hook for /// AbstractOperation. - static bool constantFoldHook(const OperationInst *op, + static bool constantFoldHook(const Instruction *op, ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results) { auto result = @@ -327,7 +327,7 @@ public: } /// This is an implementation detail of the folder hook for AbstractOperation. - static bool foldHook(OperationInst *op, SmallVectorImpl<Value *> &results) { + static bool foldHook(Instruction *op, SmallVectorImpl<Value *> &results) { auto *result = op->cast<ConcreteType>()->fold(); if (!result) return true; @@ -362,7 +362,7 @@ public: }; //===----------------------------------------------------------------------===// -// OperationInst Trait Types +// Instruction Trait Types //===----------------------------------------------------------------------===// namespace OpTrait { @@ -371,22 +371,22 @@ namespace OpTrait { // corresponding trait classes. This avoids them being template // instantiated/duplicated. namespace impl { -bool verifyZeroOperands(const OperationInst *op); -bool verifyOneOperand(const OperationInst *op); -bool verifyNOperands(const OperationInst *op, unsigned numOperands); -bool verifyAtLeastNOperands(const OperationInst *op, unsigned numOperands); -bool verifyOperandsAreIntegerLike(const OperationInst *op); -bool verifySameTypeOperands(const OperationInst *op); -bool verifyZeroResult(const OperationInst *op); -bool verifyOneResult(const OperationInst *op); -bool verifyNResults(const OperationInst *op, unsigned numOperands); -bool verifyAtLeastNResults(const OperationInst *op, unsigned numOperands); -bool verifySameOperandsAndResultShape(const OperationInst *op); -bool verifySameOperandsAndResultType(const OperationInst *op); -bool verifyResultsAreBoolLike(const OperationInst *op); -bool verifyResultsAreFloatLike(const OperationInst *op); -bool verifyResultsAreIntegerLike(const OperationInst *op); -bool verifyIsTerminator(const OperationInst *op); +bool verifyZeroOperands(const Instruction *op); +bool verifyOneOperand(const Instruction *op); +bool verifyNOperands(const Instruction *op, unsigned numOperands); +bool verifyAtLeastNOperands(const Instruction *op, unsigned numOperands); +bool verifyOperandsAreIntegerLike(const Instruction *op); +bool verifySameTypeOperands(const Instruction *op); +bool verifyZeroResult(const Instruction *op); +bool verifyOneResult(const Instruction *op); +bool verifyNResults(const Instruction *op, unsigned numOperands); +bool verifyAtLeastNResults(const Instruction *op, unsigned numOperands); +bool verifySameOperandsAndResultShape(const Instruction *op); +bool verifySameOperandsAndResultType(const Instruction *op); +bool verifyResultsAreBoolLike(const Instruction *op); +bool verifyResultsAreFloatLike(const Instruction *op); +bool verifyResultsAreIntegerLike(const Instruction *op); +bool verifyIsTerminator(const Instruction *op); } // namespace impl /// Helper class for implementing traits. Clients are not expected to interact @@ -394,8 +394,8 @@ bool verifyIsTerminator(const OperationInst *op); template <typename ConcreteType, template <typename> class TraitType> class TraitBase { protected: - /// Return the ultimate OperationInst being worked on. - OperationInst *getInstruction() { + /// Return the ultimate Instruction being worked on. + Instruction *getInstruction() { // We have to cast up to the trait type, then to the concrete type, then to // the BaseState class in explicit hops because the concrete type will // multiply derive from the (content free) TraitBase class, and we need to @@ -405,13 +405,13 @@ protected: auto *base = static_cast<OpState *>(concrete); return base->getInstruction(); } - const OperationInst *getInstruction() const { + const Instruction *getInstruction() const { return const_cast<TraitBase *>(this)->getInstruction(); } /// Provide default implementations of trait hooks. This allows traits to /// provide exactly the overrides they care about. - static bool verifyTrait(const OperationInst *op) { return false; } + static bool verifyTrait(const Instruction *op) { return false; } static AbstractOperation::OperationProperties getTraitProperties() { return 0; } @@ -422,7 +422,7 @@ protected: template <typename ConcreteType> class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> { public: - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifyZeroOperands(op); } @@ -447,7 +447,7 @@ public: this->getInstruction()->setOperand(0, value); } - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifyOneOperand(op); } }; @@ -474,7 +474,7 @@ public: this->getInstruction()->setOperand(i, value); } - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifyNOperands(op, N); } }; @@ -506,7 +506,7 @@ public: } // Support non-const operand iteration. - using operand_iterator = OperationInst::operand_iterator; + using operand_iterator = Instruction::operand_iterator; operand_iterator operand_begin() { return this->getInstruction()->operand_begin(); } @@ -518,7 +518,7 @@ public: } // Support const operand iteration. - using const_operand_iterator = OperationInst::const_operand_iterator; + using const_operand_iterator = Instruction::const_operand_iterator; const_operand_iterator operand_begin() const { return this->getInstruction()->operand_begin(); } @@ -529,7 +529,7 @@ public: return this->getInstruction()->getOperands(); } - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifyAtLeastNOperands(op, N); } }; @@ -557,7 +557,7 @@ public: } // Support non-const operand iteration. - using operand_iterator = OperationInst::operand_iterator; + using operand_iterator = Instruction::operand_iterator; operand_iterator operand_begin() { return this->getInstruction()->operand_begin(); } @@ -569,7 +569,7 @@ public: } // Support const operand iteration. - using const_operand_iterator = OperationInst::const_operand_iterator; + using const_operand_iterator = Instruction::const_operand_iterator; const_operand_iterator operand_begin() const { return this->getInstruction()->operand_begin(); } @@ -586,7 +586,7 @@ public: template <typename ConcreteType> class ZeroResult : public TraitBase<ConcreteType, ZeroResult> { public: - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifyZeroResult(op); } }; @@ -610,7 +610,7 @@ public: getResult()->replaceAllUsesWith(newValue); } - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifyOneResult(op); } }; @@ -637,7 +637,7 @@ public: Type getType(unsigned i) const { return getResult(i)->getType(); } - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifyNResults(op, N); } }; @@ -663,7 +663,7 @@ public: Type getType(unsigned i) const { return getResult(i)->getType(); } - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifyAtLeastNResults(op, N); } }; @@ -689,7 +689,7 @@ public: } // Support non-const result iteration. - using result_iterator = OperationInst::result_iterator; + using result_iterator = Instruction::result_iterator; result_iterator result_begin() { return this->getInstruction()->result_begin(); } @@ -699,7 +699,7 @@ public: } // Support const result iteration. - using const_result_iterator = OperationInst::const_result_iterator; + using const_result_iterator = Instruction::const_result_iterator; const_result_iterator result_begin() const { return this->getInstruction()->result_begin(); } @@ -718,7 +718,7 @@ template <typename ConcreteType> class SameOperandsAndResultShape : public TraitBase<ConcreteType, SameOperandsAndResultShape> { public: - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifySameOperandsAndResultShape(op); } }; @@ -733,7 +733,7 @@ template <typename ConcreteType> class SameOperandsAndResultType : public TraitBase<ConcreteType, SameOperandsAndResultType> { public: - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifySameOperandsAndResultType(op); } }; @@ -743,7 +743,7 @@ public: template <typename ConcreteType> class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> { public: - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifyResultsAreBoolLike(op); } }; @@ -754,7 +754,7 @@ template <typename ConcreteType> class ResultsAreFloatLike : public TraitBase<ConcreteType, ResultsAreFloatLike> { public: - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifyResultsAreFloatLike(op); } }; @@ -765,7 +765,7 @@ template <typename ConcreteType> class ResultsAreIntegerLike : public TraitBase<ConcreteType, ResultsAreIntegerLike> { public: - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifyResultsAreIntegerLike(op); } }; @@ -796,7 +796,7 @@ template <typename ConcreteType> class OperandsAreIntegerLike : public TraitBase<ConcreteType, OperandsAreIntegerLike> { public: - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifyOperandsAreIntegerLike(op); } }; @@ -806,7 +806,7 @@ public: template <typename ConcreteType> class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> { public: - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifySameTypeOperands(op); } }; @@ -819,7 +819,7 @@ public: return static_cast<AbstractOperation::OperationProperties>( OperationProperty::Terminator); } - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return impl::verifyIsTerminator(op); } @@ -852,7 +852,7 @@ public: } // end namespace OpTrait //===----------------------------------------------------------------------===// -// OperationInst Definition classes +// Instruction Definition classes //===----------------------------------------------------------------------===// /// This provides public APIs that all operations should have. The template @@ -867,16 +867,16 @@ class Op : public OpState, Traits<ConcreteType>...>::value> { public: /// Return the operation that this refers to. - const OperationInst *getInstruction() const { + const Instruction *getInstruction() const { return OpState::getInstruction(); } - OperationInst *getInstruction() { return OpState::getInstruction(); } + Instruction *getInstruction() { return OpState::getInstruction(); } /// Return true if this "op class" can match against the specified operation. /// This hook can be overridden with a more specific implementation in /// the subclass of Base. /// - static bool isClassFor(const OperationInst *op) { + static bool isClassFor(const Instruction *op) { return op->getName().getStringRef() == ConcreteType::getOperationName(); } @@ -890,7 +890,7 @@ public: /// This is the hook used by the AsmPrinter to emit this to the .mlir file. /// Op implementations should provide a print method. - static void printAssembly(const OperationInst *op, OpAsmPrinter *p) { + static void printAssembly(const Instruction *op, OpAsmPrinter *p) { auto opPointer = op->dyn_cast<ConcreteType>(); assert(opPointer && "op's name does not match name of concrete type instantiated with"); @@ -904,7 +904,7 @@ public: /// /// On success this returns false; on failure it emits an error to the /// diagnostic subsystem and returns true. - static bool verifyInvariants(const OperationInst *op) { + static bool verifyInvariants(const Instruction *op) { return BaseVerifier<Traits<ConcreteType>...>::verifyTrait(op) || op->cast<ConcreteType>()->verify(); } @@ -922,26 +922,26 @@ public: using ConcreteOpType = ConcreteType; protected: - explicit Op(const OperationInst *state) : OpState(state) {} + explicit Op(const Instruction *state) : OpState(state) {} private: template <typename... Types> struct BaseVerifier; template <typename First, typename... Rest> struct BaseVerifier<First, Rest...> { - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return First::verifyTrait(op) || BaseVerifier<Rest...>::verifyTrait(op); } }; template <typename First> struct BaseVerifier<First> { - static bool verifyTrait(const OperationInst *op) { + static bool verifyTrait(const Instruction *op) { return First::verifyTrait(op); } }; template <> struct BaseVerifier<> { - static bool verifyTrait(const OperationInst *op) { return false; } + static bool verifyTrait(const Instruction *op) { return false; } }; template <typename... Types> struct BaseProperties; @@ -976,7 +976,7 @@ bool parseBinaryOp(OpAsmParser *parser, OperationState *result); // Prints the given binary `op` in custom assembly form if both the two operands // and the result have the same time. Otherwise, prints the generic assembly // form. -void printBinaryOp(const OperationInst *op, OpAsmPrinter *p); +void printBinaryOp(const Instruction *op, OpAsmPrinter *p); } // namespace impl // These functions are out-of-line implementations of the methods in CastOp, @@ -985,7 +985,7 @@ namespace impl { void buildCastOp(Builder *builder, OperationState *result, Value *source, Type destType); bool parseCastOp(OpAsmParser *parser, OperationState *result); -void printCastOp(const OperationInst *op, OpAsmPrinter *p); +void printCastOp(const Instruction *op, OpAsmPrinter *p); } // namespace impl /// This template is used for operations that are cast operations, that have a @@ -1010,7 +1010,7 @@ public: } protected: - explicit CastOp(const OperationInst *state) + explicit CastOp(const Instruction *state) : Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult, OpTrait::HasNoSideEffect, Traits...>(state) {} }; diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 4e7596498e7..9d58f635ecd 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -75,7 +75,7 @@ public: /// Print a successor, and use list, of a terminator operation given the /// terminator and the successor index. - virtual void printSuccessorAndUseList(const OperationInst *term, + virtual void printSuccessorAndUseList(const Instruction *term, unsigned index) = 0; /// If the specified operation has attributes, print out an attribute @@ -87,7 +87,7 @@ public: ArrayRef<const char *> elidedAttrs = {}) = 0; /// Print the entire operation with the default generic assembly form. - virtual void printGenericOp(const OperationInst *op) = 0; + virtual void printGenericOp(const Instruction *op) = 0; /// Prints a block list. virtual void printBlockList(const BlockList &blocks, diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index b5a4ab9b0e6..aab0137af5a 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -15,7 +15,7 @@ // limitations under the License. // ============================================================================= // -// This file defines a number of support types that OperationInst and related +// This file defines a number of support types that Instruction and related // classes build on top of. // //===----------------------------------------------------------------------===// @@ -34,7 +34,6 @@ namespace mlir { class Block; class Dialect; class Instruction; -using OperationInst = Instruction; class OperationState; class OpAsmParser; class OpAsmParserResult; @@ -78,24 +77,23 @@ public: Dialect &dialect; /// Return true if this "op class" can match against the specified operation. - bool (&isClassFor)(const OperationInst *op); + bool (&isClassFor)(const Instruction *op); /// Use the specified object to parse this ops custom assembly format. bool (&parseAssembly)(OpAsmParser *parser, OperationState *result); /// This hook implements the AsmPrinter for this operation. - void (&printAssembly)(const OperationInst *op, OpAsmPrinter *p); + void (&printAssembly)(const Instruction *op, OpAsmPrinter *p); /// This hook implements the verifier for this operation. It should emits an /// error message and returns true if a problem is detected, or returns false /// if everything is ok. - bool (&verifyInvariants)(const OperationInst *op); + bool (&verifyInvariants)(const Instruction *op); /// This hook implements a constant folder for this operation. It returns /// true if folding failed, or returns false and fills in `results` on /// success. - bool (&constantFoldHook)(const OperationInst *op, - ArrayRef<Attribute> operands, + bool (&constantFoldHook)(const Instruction *op, ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results); /// This hook implements a generalized folder for this operation. Operations @@ -118,7 +116,7 @@ public: /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does /// not allow for canonicalizations that need to introduce new operations, not /// even constants (e.g. "x-x -> 0" cannot be expressed). - bool (&foldHook)(OperationInst *op, SmallVectorImpl<Value *> &results); + bool (&foldHook)(Instruction *op, SmallVectorImpl<Value *> &results); /// This hook returns any canonicalization pattern rewrites that the operation /// supports, for use by the canonicalization pass. @@ -147,14 +145,14 @@ public: private: AbstractOperation( StringRef name, Dialect &dialect, OperationProperties opProperties, - bool (&isClassFor)(const OperationInst *op), + bool (&isClassFor)(const Instruction *op), bool (&parseAssembly)(OpAsmParser *parser, OperationState *result), - void (&printAssembly)(const OperationInst *op, OpAsmPrinter *p), - bool (&verifyInvariants)(const OperationInst *op), - bool (&constantFoldHook)(const OperationInst *op, + void (&printAssembly)(const Instruction *op, OpAsmPrinter *p), + bool (&verifyInvariants)(const Instruction *op), + bool (&constantFoldHook)(const Instruction *op, ArrayRef<Attribute> operands, SmallVectorImpl<Attribute> &results), - bool (&foldHook)(OperationInst *op, SmallVectorImpl<Value *> &results), + bool (&foldHook)(Instruction *op, SmallVectorImpl<Value *> &results), void (&getCanonicalizationPatterns)(OwningRewritePatternList &results, MLIRContext *context)) : name(name), dialect(dialect), isClassFor(isClassFor), diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index aa5a14e75d6..b9cc12460ff 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -108,7 +108,7 @@ public: /// returns a None value. On success it a (possibly null) pattern-specific /// state wrapped in a Some. This state is passed back into its rewrite /// function if this match is selected. - virtual PatternMatchResult match(OperationInst *op) const = 0; + virtual PatternMatchResult match(Instruction *op) const = 0; virtual ~Pattern() {} @@ -148,7 +148,7 @@ public: /// rewriter. If an unexpected error is encountered (an internal /// compiler error), it is emitted through the normal MLIR diagnostic /// hooks and the IR is left in a valid state. - virtual void rewrite(OperationInst *op, std::unique_ptr<PatternState> state, + virtual void rewrite(Instruction *op, std::unique_ptr<PatternState> state, PatternRewriter &rewriter) const; /// Rewrite the IR rooted at the specified operation with the result of @@ -156,7 +156,7 @@ public: /// builder. If an unexpected error is encountered (an internal /// compiler error), it is emitted through the normal MLIR diagnostic /// hooks and the IR is left in a valid state. - virtual void rewrite(OperationInst *op, PatternRewriter &rewriter) const; + virtual void rewrite(Instruction *op, PatternRewriter &rewriter) const; protected: /// Patterns must specify the root operation name they match against, and can @@ -222,13 +222,13 @@ public: /// clients can specify a list of other nodes that this replacement may make /// (perhaps transitively) dead. If any of those values are dead, this will /// remove them as well. - void replaceOp(OperationInst *op, ArrayRef<Value *> newValues, + void replaceOp(Instruction *op, ArrayRef<Value *> newValues, ArrayRef<Value *> valuesToRemoveIfDead = {}); /// Replaces the result op with a new op that is created without verification. /// The result values of the two ops must be the same types. template <typename OpTy, typename... Args> - void replaceOpWithNewOp(OperationInst *op, Args... args) { + void replaceOpWithNewOp(Instruction *op, Args... args) { auto newOp = create<OpTy>(op->getLoc(), args...); replaceOpWithResultsOfAnotherOp(op, newOp->getInstruction(), {}); } @@ -237,7 +237,7 @@ public: /// The result values of the two ops must be the same types. This allows /// specifying a list of ops that may be removed if dead. template <typename OpTy, typename... Args> - void replaceOpWithNewOp(OperationInst *op, + void replaceOpWithNewOp(Instruction *op, ArrayRef<Value *> valuesToRemoveIfDead, Args... args) { auto newOp = create<OpTy>(op->getLoc(), args...); @@ -253,7 +253,7 @@ public: /// The valuesToRemoveIfDead list is an optional list of values that the /// rewriter should remove if they are dead at this point. /// - void updatedRootInPlace(OperationInst *op, + void updatedRootInPlace(Instruction *op, ArrayRef<Value *> valuesToRemoveIfDead = {}); protected: @@ -265,26 +265,26 @@ protected: /// This is implemented to create the specified operations and serves as a /// notification hook for rewriters that want to know about new operations. - virtual OperationInst *createOperation(const OperationState &state) = 0; + virtual Instruction *createOperation(const OperationState &state) = 0; /// Notify the pattern rewriter that the specified operation has been mutated /// in place. This is called after the mutation is done. - virtual void notifyRootUpdated(OperationInst *op) {} + virtual void notifyRootUpdated(Instruction *op) {} /// Notify the pattern rewriter that the specified operation is about to be /// replaced with another set of operations. This is called before the uses /// of the operation have been changed. - virtual void notifyRootReplaced(OperationInst *op) {} + virtual void notifyRootReplaced(Instruction *op) {} /// This is called on an operation that a pattern match is removing, right /// before the operation is deleted. At this point, the operation has zero /// uses. - virtual void notifyOperationRemoved(OperationInst *op) {} + virtual void notifyOperationRemoved(Instruction *op) {} private: /// op and newOp are known to have the same number of results, replace the /// uses of op with uses of newOp - void replaceOpWithResultsOfAnotherOp(OperationInst *op, OperationInst *newOp, + void replaceOpWithResultsOfAnotherOp(Instruction *op, Instruction *newOp, ArrayRef<Value *> valuesToRemoveIfDead); }; @@ -317,7 +317,7 @@ public: /// Find the highest benefit pattern available in the pattern set for the DAG /// rooted at the specified node. This returns the pattern (and any state it /// needs) if found, or null if there are no matches. - MatchResult findMatch(OperationInst *op); + MatchResult findMatch(Instruction *op); private: PatternMatcher(const PatternMatcher &) = delete; diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index a80b0cbf788..e726d9c7ae3 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -95,8 +95,7 @@ public: insertIntoCurrent(); } - /// Return the owner of this operand, for example, the OperationInst that - /// contains an InstOperand. + /// Return the owner of this operand. Instruction *getOwner() { return owner; } const Instruction *getOwner() const { return owner; } diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 8ecfa826243..465846d026c 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -30,7 +30,6 @@ namespace mlir { class Block; class Function; class Instruction; -using OperationInst = Instruction; class Value; /// Operands contain a Value. @@ -79,8 +78,8 @@ public: /// If this value is the result of an operation, return the instruction /// that defines it. - OperationInst *getDefiningInst(); - const OperationInst *getDefiningInst() const { + Instruction *getDefiningInst(); + const Instruction *getDefiningInst() const { return const_cast<Value *>(this)->getDefiningInst(); } @@ -157,15 +156,15 @@ private: /// This is a value defined by a result of an operation instruction. class InstResult : public Value { public: - InstResult(Type type, OperationInst *owner) + InstResult(Type type, Instruction *owner) : Value(Value::Kind::InstResult, type), owner(owner) {} static bool classof(const Value *value) { return value->getKind() == Kind::InstResult; } - OperationInst *getOwner() { return owner; } - const OperationInst *getOwner() const { return owner; } + Instruction *getOwner() { return owner; } + const Instruction *getOwner() const { return owner; } /// Returns the number of this result. unsigned getResultNumber() const; @@ -174,7 +173,7 @@ private: /// The owner of this operand. /// TODO: can encode this more efficiently to avoid the space hit of this /// through bitpacking shenanigans. - OperationInst *const owner; + Instruction *const owner; }; /// This is a helper template used to implement an iterator that contains a diff --git a/mlir/include/mlir/StandardOps/StandardOps.h b/mlir/include/mlir/StandardOps/StandardOps.h index 42838f7f6ce..756d903bc33 100644 --- a/mlir/include/mlir/StandardOps/StandardOps.h +++ b/mlir/include/mlir/StandardOps/StandardOps.h @@ -80,7 +80,7 @@ public: private: friend class Instruction; - explicit AllocOp(const OperationInst *state) : Op(state) {} + explicit AllocOp(const Instruction *state) : Op(state) {} }; /// The "call" operation represents a direct call to a function. The operands @@ -122,7 +122,7 @@ public: protected: friend class Instruction; - explicit CallOp(const OperationInst *state) : Op(state) {} + explicit CallOp(const Instruction *state) : Op(state) {} }; /// The "call_indirect" operation represents an indirect call to a value of @@ -167,7 +167,7 @@ public: protected: friend class Instruction; - explicit CallIndirectOp(const OperationInst *state) : Op(state) {} + explicit CallIndirectOp(const Instruction *state) : Op(state) {} }; /// The predicate indicates the type of the comparison to perform: @@ -234,7 +234,7 @@ public: private: friend class Instruction; - explicit CmpIOp(const OperationInst *state) : Op(state) {} + explicit CmpIOp(const Instruction *state) : Op(state) {} }; /// The "dealloc" operation frees the region of memory referenced by a memref @@ -266,7 +266,7 @@ public: private: friend class Instruction; - explicit DeallocOp(const OperationInst *state) : Op(state) {} + explicit DeallocOp(const Instruction *state) : Op(state) {} }; /// The "dim" operation takes a memref or tensor operand and returns an @@ -298,7 +298,7 @@ public: private: friend class Instruction; - explicit DimOp(const OperationInst *state) : Op(state) {} + explicit DimOp(const Instruction *state) : Op(state) {} }; // DmaStartOp starts a non-blocking DMA operation that transfers data from a @@ -355,7 +355,7 @@ public: return getSrcMemRef()->getType().cast<MemRefType>().getRank(); } // Returns the source memerf indices for this DMA operation. - llvm::iterator_range<OperationInst::const_operand_iterator> + llvm::iterator_range<Instruction::const_operand_iterator> getSrcIndices() const { return {getInstruction()->operand_begin() + 1, getInstruction()->operand_begin() + 1 + getSrcMemRefRank()}; @@ -377,7 +377,7 @@ public: } // Returns the destination memref indices for this DMA operation. - llvm::iterator_range<OperationInst::const_operand_iterator> + llvm::iterator_range<Instruction::const_operand_iterator> getDstIndices() const { return {getInstruction()->operand_begin() + 1 + getSrcMemRefRank() + 1, getInstruction()->operand_begin() + 1 + getSrcMemRefRank() + 1 + @@ -399,7 +399,7 @@ public: } // Returns the tag memref index for this DMA operation. - llvm::iterator_range<OperationInst::const_operand_iterator> + llvm::iterator_range<Instruction::const_operand_iterator> getTagIndices() const { unsigned tagIndexStartPos = 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1; @@ -460,7 +460,7 @@ public: protected: friend class Instruction; - explicit DmaStartOp(const OperationInst *state) : Op(state) {} + explicit DmaStartOp(const Instruction *state) : Op(state) {} }; // DmaWaitOp blocks until the completion of a DMA operation associated with the @@ -489,7 +489,7 @@ public: Value *getTagMemRef() { return getOperand(0); } // Returns the tag memref index for this DMA operation. - llvm::iterator_range<OperationInst::const_operand_iterator> + llvm::iterator_range<Instruction::const_operand_iterator> getTagIndices() const { return {getInstruction()->operand_begin() + 1, getInstruction()->operand_begin() + 1 + getTagMemRefRank()}; @@ -512,7 +512,7 @@ public: protected: friend class Instruction; - explicit DmaWaitOp(const OperationInst *state) : Op(state) {} + explicit DmaWaitOp(const Instruction *state) : Op(state) {} }; /// The "extract_element" op reads a tensor or vector and returns one element @@ -536,13 +536,12 @@ public: Value *getAggregate() { return getOperand(0); } const Value *getAggregate() const { return getOperand(0); } - llvm::iterator_range<OperationInst::operand_iterator> getIndices() { + llvm::iterator_range<Instruction::operand_iterator> getIndices() { return {getInstruction()->operand_begin() + 1, getInstruction()->operand_end()}; } - llvm::iterator_range<OperationInst::const_operand_iterator> - getIndices() const { + llvm::iterator_range<Instruction::const_operand_iterator> getIndices() const { return {getInstruction()->operand_begin() + 1, getInstruction()->operand_end()}; } @@ -558,7 +557,7 @@ public: private: friend class Instruction; - explicit ExtractElementOp(const OperationInst *state) : Op(state) {} + explicit ExtractElementOp(const Instruction *state) : Op(state) {} }; /// The "load" op reads an element from a memref specified by an index list. The @@ -583,13 +582,12 @@ public: return getMemRef()->getType().cast<MemRefType>(); } - llvm::iterator_range<OperationInst::operand_iterator> getIndices() { + llvm::iterator_range<Instruction::operand_iterator> getIndices() { return {getInstruction()->operand_begin() + 1, getInstruction()->operand_end()}; } - llvm::iterator_range<OperationInst::const_operand_iterator> - getIndices() const { + llvm::iterator_range<Instruction::const_operand_iterator> getIndices() const { return {getInstruction()->operand_begin() + 1, getInstruction()->operand_end()}; } @@ -604,7 +602,7 @@ public: private: friend class Instruction; - explicit LoadOp(const OperationInst *state) : Op(state) {} + explicit LoadOp(const Instruction *state) : Op(state) {} }; /// The "memref_cast" operation converts a memref from one type to an equivalent @@ -635,7 +633,7 @@ public: private: friend class Instruction; - explicit MemRefCastOp(const OperationInst *state) : CastOp(state) {} + explicit MemRefCastOp(const Instruction *state) : CastOp(state) {} }; /// The "select" operation chooses one value based on a binary condition @@ -671,7 +669,7 @@ public: private: friend class Instruction; - explicit SelectOp(const OperationInst *state) : Op(state) {} + explicit SelectOp(const Instruction *state) : Op(state) {} }; /// The "store" op writes an element to a memref specified by an index list. @@ -702,13 +700,12 @@ public: return getMemRef()->getType().cast<MemRefType>(); } - llvm::iterator_range<OperationInst::operand_iterator> getIndices() { + llvm::iterator_range<Instruction::operand_iterator> getIndices() { return {getInstruction()->operand_begin() + 2, getInstruction()->operand_end()}; } - llvm::iterator_range<OperationInst::const_operand_iterator> - getIndices() const { + llvm::iterator_range<Instruction::const_operand_iterator> getIndices() const { return {getInstruction()->operand_begin() + 2, getInstruction()->operand_end()}; } @@ -724,7 +721,7 @@ public: private: friend class Instruction; - explicit StoreOp(const OperationInst *state) : Op(state) {} + explicit StoreOp(const Instruction *state) : Op(state) {} }; /// The "tensor_cast" operation converts a tensor from one type to an equivalent @@ -750,7 +747,7 @@ public: private: friend class Instruction; - explicit TensorCastOp(const OperationInst *state) : CastOp(state) {} + explicit TensorCastOp(const Instruction *state) : CastOp(state) {} }; } // end namespace mlir diff --git a/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h b/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h index 4e8f3f7f328..6bce67605a9 100644 --- a/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h +++ b/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h @@ -113,9 +113,8 @@ public: MemRefType getMemRefType() const { return getMemRef()->getType().cast<MemRefType>(); } - llvm::iterator_range<OperationInst::operand_iterator> getIndices(); - llvm::iterator_range<OperationInst::const_operand_iterator> - getIndices() const; + llvm::iterator_range<Instruction::operand_iterator> getIndices(); + llvm::iterator_range<Instruction::const_operand_iterator> getIndices() const; Optional<Value *> getPaddingValue(); Optional<const Value *> getPaddingValue() const; AffineMap getPermutationMap() const; @@ -126,7 +125,7 @@ public: private: friend class Instruction; - explicit VectorTransferReadOp(const OperationInst *state) : Op(state) {} + explicit VectorTransferReadOp(const Instruction *state) : Op(state) {} }; /// VectorTransferWriteOp performs a blocking write from a super-vector to @@ -182,9 +181,8 @@ public: MemRefType getMemRefType() const { return getMemRef()->getType().cast<MemRefType>(); } - llvm::iterator_range<OperationInst::operand_iterator> getIndices(); - llvm::iterator_range<OperationInst::const_operand_iterator> - getIndices() const; + llvm::iterator_range<Instruction::operand_iterator> getIndices(); + llvm::iterator_range<Instruction::const_operand_iterator> getIndices() const; AffineMap getPermutationMap() const; static bool parse(OpAsmParser *parser, OperationState *result); @@ -193,7 +191,7 @@ public: private: friend class Instruction; - explicit VectorTransferWriteOp(const OperationInst *state) : Op(state) {} + explicit VectorTransferWriteOp(const Instruction *state) : Op(state) {} }; /// VectorTypeCastOp performs a conversion from a memref with scalar element to @@ -217,7 +215,7 @@ public: private: friend class Instruction; - explicit VectorTypeCastOp(const OperationInst *state) : Op(state) {} + explicit VectorTypeCastOp(const Instruction *state) : Op(state) {} }; } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 0d1bdace9a5..348201b3976 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -33,7 +33,6 @@ class Block; class FuncBuilder; class Instruction; class MLIRContext; -using OperationInst = Instruction; class Type; class Value; @@ -43,7 +42,7 @@ class FunctionConversion; } /// Base class for the dialect op conversion patterns. Specific conversions -/// must derive this class and implement `PatternMatch match(OperationInst *)` +/// must derive this class and implement `PatternMatch match(Instruction *)` /// defined in `Pattern` and at least one of `rewrite` and `rewriteTerminator`. // // TODO(zinenko): this should eventually converge with RewritePattern. So far, @@ -67,7 +66,7 @@ public: /// DialectOpConversion ever needs to replace an operation that does not have /// successors. This function should not fail. If some specific cases of the /// operation are not supported, these cases should not be matched. - virtual SmallVector<Value *, 4> rewrite(OperationInst *op, + virtual SmallVector<Value *, 4> rewrite(Instruction *op, ArrayRef<Value *> operands, FuncBuilder &rewriter) const { llvm_unreachable("unimplemented rewrite, did you mean rewriteTerminator?"); @@ -85,7 +84,7 @@ public: /// successors. This function should not fail the pass. If some specific /// cases of the operation are not supported, these cases should not be /// matched. - virtual void rewriteTerminator(OperationInst *op, + virtual void rewriteTerminator(Instruction *op, ArrayRef<Value *> properOperands, ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands, diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h index 00c6577240c..c6f810a215c 100644 --- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h +++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h @@ -37,7 +37,7 @@ public: FuncBuilder *getBuilder() { return builder; } - OperationInst *createOperation(const OperationState &state) override { + Instruction *createOperation(const OperationState &state) override { auto *result = builder->createOperation(state); return result; } @@ -66,7 +66,7 @@ public: /// must override). It will be passed the function-wise state, common to all /// matches, and the state returned by the `match` call, if any. The subclass /// must use `rewriter` to modify the function. - virtual void rewriteOpInst(OperationInst *op, + virtual void rewriteOpInst(Instruction *op, MLFuncGlobalLoweringState *funcWiseState, std::unique_ptr<PatternState> opState, MLFuncLoweringRewriter *rewriter) const = 0; @@ -143,10 +143,10 @@ PassResult MLPatternLoweringPass<Patterns...>::runOnFunction(Function *f) { FuncBuilder builder(f); MLFuncLoweringRewriter rewriter(&builder); - llvm::SmallVector<OperationInst *, 16> ops; - f->walkOps([&ops](OperationInst *inst) { ops.push_back(inst); }); + llvm::SmallVector<Instruction *, 16> ops; + f->walkOps([&ops](Instruction *inst) { ops.push_back(inst); }); - for (OperationInst *inst : ops) { + for (Instruction *inst : ops) { for (const auto &pattern : patterns) { rewriter.getBuilder()->setInsertionPoint(inst); auto matchResult = pattern->match(inst); diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h index c43652b316b..784e68a5ab3 100644 --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -75,11 +75,10 @@ bool replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, /// these will also be collected into a single (multi-result) affine apply op. /// The final results of the composed AffineApplyOp are returned in output /// parameter 'results'. Returns the affine apply op created. -OperationInst * -createComposedAffineApplyOp(FuncBuilder *builder, Location loc, - ArrayRef<Value *> operands, - ArrayRef<OperationInst *> affineApplyOps, - SmallVectorImpl<Value *> *results); +Instruction *createComposedAffineApplyOp(FuncBuilder *builder, Location loc, + ArrayRef<Value *> operands, + ArrayRef<Instruction *> affineApplyOps, + SmallVectorImpl<Value *> *results); /// Given an operation instruction, inserts one or more single result affine /// apply operations, results of which are exclusively used by this operation @@ -110,7 +109,7 @@ createComposedAffineApplyOp(FuncBuilder *builder, Location loc, /// uses other than those in this opInst. The method otherwise returns the list /// of affine_apply operations created in output argument `sliceOps`. void createAffineComputationSlice( - OperationInst *opInst, SmallVectorImpl<OpPointer<AffineApplyOp>> *sliceOps); + Instruction *opInst, SmallVectorImpl<OpPointer<AffineApplyOp>> *sliceOps); /// Folds the lower and upper bounds of a 'for' inst to constants if possible. /// Returns false if the folding happens for at least one bound, true otherwise. @@ -119,7 +118,7 @@ bool constantFoldBounds(OpPointer<AffineForOp> forInst); /// Replaces (potentially nested) function attributes in the operation "op" /// with those specified in "remappingTable". void remapFunctionAttrs( - OperationInst &op, const DenseMap<Attribute, FunctionAttr> &remappingTable); + Instruction &op, const DenseMap<Attribute, FunctionAttr> &remappingTable); /// Replaces (potentially nested) function attributes all operations of the /// Function "fn" with those specified in "remappingTable". diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index f1693c8e449..682a8e4f1ed 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -410,13 +410,13 @@ bool AffineForOp::matchingBoundOperandList() const { return true; } -void AffineForOp::walkOps(std::function<void(OperationInst *)> callback) { +void AffineForOp::walkOps(std::function<void(Instruction *)> callback) { struct Walker : public InstWalker<Walker> { - std::function<void(OperationInst *)> const &callback; - Walker(std::function<void(OperationInst *)> const &callback) + std::function<void(Instruction *)> const &callback; + Walker(std::function<void(Instruction *)> const &callback) : callback(callback) {} - void visitOperationInst(OperationInst *opInst) { callback(opInst); } + void visitOperationInst(Instruction *opInst) { callback(opInst); } }; Walker w(callback); @@ -424,13 +424,13 @@ void AffineForOp::walkOps(std::function<void(OperationInst *)> callback) { } void AffineForOp::walkOpsPostOrder( - std::function<void(OperationInst *)> callback) { + std::function<void(Instruction *)> callback) { struct Walker : public InstWalker<Walker> { - std::function<void(OperationInst *)> const &callback; - Walker(std::function<void(OperationInst *)> const &callback) + std::function<void(Instruction *)> const &callback; + Walker(std::function<void(Instruction *)> const &callback) : callback(callback) {} - void visitOperationInst(OperationInst *opInst) { callback(opInst); } + void visitOperationInst(Instruction *opInst) { callback(opInst); } }; Walker v(callback); @@ -454,7 +454,7 @@ OpPointer<AffineForOp> mlir::getForInductionVarOwner(Value *val) { auto *containingInst = ivArg->getOwner()->getParent()->getContainingInst(); if (!containingInst) return OpPointer<AffineForOp>(); - return cast<OperationInst>(containingInst)->dyn_cast<AffineForOp>(); + return containingInst->dyn_cast<AffineForOp>(); } ConstOpPointer<AffineForOp> mlir::getForInductionVarOwner(const Value *val) { auto nonConstOwner = getForInductionVarOwner(const_cast<Value *>(val)); diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 8b64f498ce8..0936798d71a 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -498,14 +498,14 @@ bool mlir::getFlattenedAffineExprs( localVarCst); } -/// Returns the sequence of AffineApplyOp OperationInsts operation in +/// Returns the sequence of AffineApplyOp Instructions operation in /// 'affineApplyOps', which are reachable via a search starting from 'operands', /// and ending at operands which are not defined by AffineApplyOps. // TODO(andydavis) Add a method to AffineApplyOp which forward substitutes // the AffineApplyOp into any user AffineApplyOps. void mlir::getReachableAffineApplyOps( ArrayRef<Value *> operands, - SmallVectorImpl<OperationInst *> &affineApplyOps) { + SmallVectorImpl<Instruction *> &affineApplyOps) { struct State { // The ssa value for this node in the DFS traversal. Value *value; @@ -521,7 +521,7 @@ void mlir::getReachableAffineApplyOps( State &state = worklist.back(); auto *opInst = state.value->getDefiningInst(); // Note: getDefiningInst will return nullptr if the operand is not an - // OperationInst (i.e. AffineForOp), which is a terminator for the search. + // Instruction (i.e. AffineForOp), which is a terminator for the search. if (opInst == nullptr || !opInst->isa<AffineApplyOp>()) { worklist.pop_back(); continue; diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index e3055c5530d..4ded1bfc400 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -129,7 +129,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(ConstOpPointer<AffineForOp> forOp) { bool mlir::isAccessInvariant(const Value &iv, const Value &index) { assert(isForInductionVar(&iv) && "iv must be a AffineForOp"); assert(index.getType().isa<IndexType>() && "index must be of IndexType"); - SmallVector<OperationInst *, 4> affineApplyOps; + SmallVector<Instruction *, 4> affineApplyOps; getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps); if (affineApplyOps.empty()) { @@ -226,17 +226,15 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) { } static bool isVectorTransferReadOrWrite(const Instruction &inst) { - const auto *opInst = cast<OperationInst>(&inst); - return opInst->isa<VectorTransferReadOp>() || - opInst->isa<VectorTransferWriteOp>(); + return inst.isa<VectorTransferReadOp>() || inst.isa<VectorTransferWriteOp>(); } using VectorizableInstFun = - std::function<bool(ConstOpPointer<AffineForOp>, const OperationInst &)>; + std::function<bool(ConstOpPointer<AffineForOp>, const Instruction &)>; static bool isVectorizableLoopWithCond(ConstOpPointer<AffineForOp> loop, VectorizableInstFun isVectorizableInst) { - auto *forInst = const_cast<OperationInst *>(loop->getInstruction()); + auto *forInst = const_cast<Instruction *>(loop->getInstruction()); if (!matcher::isParallelLoop(*forInst) && !matcher::isReductionLoop(*forInst)) { return false; @@ -252,9 +250,8 @@ static bool isVectorizableLoopWithCond(ConstOpPointer<AffineForOp> loop, // No vectorization across unknown regions. auto regions = matcher::Op([](const Instruction &inst) -> bool { - auto &opInst = cast<OperationInst>(inst); - return opInst.getNumBlockLists() != 0 && - !(opInst.isa<AffineIfOp>() || opInst.isa<AffineForOp>()); + return inst.getNumBlockLists() != 0 && + !(inst.isa<AffineIfOp>() || inst.isa<AffineForOp>()); }); SmallVector<NestedMatch, 8> regionsMatched; regions.match(forInst, ®ionsMatched); @@ -273,7 +270,7 @@ static bool isVectorizableLoopWithCond(ConstOpPointer<AffineForOp> loop, SmallVector<NestedMatch, 8> loadAndStoresMatched; loadAndStores.match(forInst, &loadAndStoresMatched); for (auto ls : loadAndStoresMatched) { - auto *op = cast<OperationInst>(ls.getMatchedInstruction()); + auto *op = ls.getMatchedInstruction(); auto load = op->dyn_cast<LoadOp>(); auto store = op->dyn_cast<StoreOp>(); // Only scalar types are considered vectorizable, all load/store must be @@ -293,7 +290,7 @@ static bool isVectorizableLoopWithCond(ConstOpPointer<AffineForOp> loop, bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( ConstOpPointer<AffineForOp> loop, unsigned fastestVaryingDim) { VectorizableInstFun fun([fastestVaryingDim](ConstOpPointer<AffineForOp> loop, - const OperationInst &op) { + const Instruction &op) { auto load = op.dyn_cast<LoadOp>(); auto store = op.dyn_cast<StoreOp>(); return load ? isContiguousAccess(*loop->getInductionVar(), *load, @@ -307,7 +304,7 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( bool mlir::isVectorizableLoop(ConstOpPointer<AffineForOp> loop) { VectorizableInstFun fun( // TODO: implement me - [](ConstOpPointer<AffineForOp> loop, const OperationInst &op) { + [](ConstOpPointer<AffineForOp> loop, const Instruction &op) { return true; }); return isVectorizableLoopWithCond(loop, fun); @@ -324,20 +321,16 @@ bool mlir::isInstwiseShiftValid(ConstOpPointer<AffineForOp> forOp, assert(shifts.size() == forBody->getInstructions().size()); unsigned s = 0; for (const auto &inst : *forBody) { - // A for or if inst does not produce any def/results (that are used - // outside). - if (const auto *opInst = dyn_cast<OperationInst>(&inst)) { - for (unsigned i = 0, e = opInst->getNumResults(); i < e; ++i) { - const Value *result = opInst->getResult(i); - for (const InstOperand &use : result->getUses()) { - // If an ancestor instruction doesn't lie in the block of forOp, - // there is no shift to check. This is a naive way. If performance - // becomes an issue, a map can be used to store 'shifts' - to look up - // the shift for a instruction in constant time. - if (auto *ancInst = forBody->findAncestorInstInBlock(*use.getOwner())) - if (shifts[s] != shifts[forBody->findInstPositionInBlock(*ancInst)]) - return false; - } + for (unsigned i = 0, e = inst.getNumResults(); i < e; ++i) { + const Value *result = inst.getResult(i); + for (const InstOperand &use : result->getUses()) { + // If an ancestor instruction doesn't lie in the block of forOp, + // there is no shift to check. This is a naive way. If performance + // becomes an issue, a map can be used to store 'shifts' - to look up + // the shift for a instruction in constant time. + if (auto *ancInst = forBody->findAncestorInstInBlock(*use.getOwner())) + if (shifts[s] != shifts[forBody->findInstPositionInBlock(*ancInst)]) + return false; } } s++; diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index d21f2f8035b..ab22f261a3b 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -43,7 +43,7 @@ struct MemRefBoundCheck : public FunctionPass, InstWalker<MemRefBoundCheck> { PassResult runOnFunction(Function *f) override; - void visitOperationInst(OperationInst *opInst); + void visitInstruction(Instruction *opInst); static char passID; }; @@ -56,7 +56,7 @@ FunctionPass *mlir::createMemRefBoundCheckPass() { return new MemRefBoundCheck(); } -void MemRefBoundCheck::visitOperationInst(OperationInst *opInst) { +void MemRefBoundCheck::visitInstruction(Instruction *opInst) { if (auto loadOp = opInst->dyn_cast<LoadOp>()) { boundCheckLoadOrStoreOp(loadOp); } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) { diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 043d62d0cc9..b2549910a17 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -40,13 +40,13 @@ namespace { /// Checks dependences between all pairs of memref accesses in a Function. struct MemRefDependenceCheck : public FunctionPass, InstWalker<MemRefDependenceCheck> { - SmallVector<OperationInst *, 4> loadsAndStores; + SmallVector<Instruction *, 4> loadsAndStores; explicit MemRefDependenceCheck() : FunctionPass(&MemRefDependenceCheck::passID) {} PassResult runOnFunction(Function *f) override; - void visitOperationInst(OperationInst *opInst) { + void visitOperationInst(Instruction *opInst) { if (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()) { loadsAndStores.push_back(opInst); } @@ -88,7 +88,7 @@ getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth, // "source" access and all subsequent "destination" accesses in // 'loadsAndStores'. Emits the result of the dependence check as a note with // the source access. -static void checkDependences(ArrayRef<OperationInst *> loadsAndStores) { +static void checkDependences(ArrayRef<Instruction *> loadsAndStores) { for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) { auto *srcOpInst = loadsAndStores[i]; MemRefAccess srcAccess(srcOpInst); diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 214b4ce403c..ec1b60ee437 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -48,10 +48,9 @@ llvm::BumpPtrAllocator *&NestedPattern::allocator() { return allocator; } -NestedPattern::NestedPattern(Instruction::Kind k, - ArrayRef<NestedPattern> nested, +NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested, FilterFunctionType filter) - : kind(k), nestedPatterns(), filter(filter), skip(nullptr) { + : nestedPatterns(), filter(filter), skip(nullptr) { if (!nested.empty()) { auto *newNested = allocator()->Allocate<NestedPattern>(nested.size()); std::uninitialized_copy(nested.begin(), nested.end(), newNested); @@ -85,10 +84,6 @@ void NestedPattern::matchOne(Instruction *inst, if (skip == inst) { return; } - // Structural filter - if (inst->getKind() != kind) { - return; - } // Local custom filter function if (!filter(*inst)) { return; @@ -116,74 +111,68 @@ void NestedPattern::matchOne(Instruction *inst, } static bool isAffineForOp(const Instruction &inst) { - return cast<OperationInst>(inst).isa<AffineForOp>(); + return inst.isa<AffineForOp>(); } static bool isAffineIfOp(const Instruction &inst) { - return isa<OperationInst>(inst) && - cast<OperationInst>(inst).isa<AffineIfOp>(); + return inst.isa<AffineIfOp>(); } namespace mlir { namespace matcher { NestedPattern Op(FilterFunctionType filter) { - return NestedPattern(Instruction::Kind::OperationInst, {}, filter); + return NestedPattern({}, filter); } NestedPattern If(NestedPattern child) { - return NestedPattern(Instruction::Kind::OperationInst, child, isAffineIfOp); + return NestedPattern(child, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(Instruction::Kind::OperationInst, child, - [filter](const Instruction &inst) { - return isAffineIfOp(inst) && filter(inst); - }); + return NestedPattern(child, [filter](const Instruction &inst) { + return isAffineIfOp(inst) && filter(inst); + }); } NestedPattern If(ArrayRef<NestedPattern> nested) { - return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineIfOp); + return NestedPattern(nested, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) { - return NestedPattern(Instruction::Kind::OperationInst, nested, - [filter](const Instruction &inst) { - return isAffineIfOp(inst) && filter(inst); - }); + return NestedPattern(nested, [filter](const Instruction &inst) { + return isAffineIfOp(inst) && filter(inst); + }); } NestedPattern For(NestedPattern child) { - return NestedPattern(Instruction::Kind::OperationInst, child, isAffineForOp); + return NestedPattern(child, isAffineForOp); } NestedPattern For(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(Instruction::Kind::OperationInst, child, - [=](const Instruction &inst) { - return isAffineForOp(inst) && filter(inst); - }); + return NestedPattern(child, [=](const Instruction &inst) { + return isAffineForOp(inst) && filter(inst); + }); } NestedPattern For(ArrayRef<NestedPattern> nested) { - return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineForOp); + return NestedPattern(nested, isAffineForOp); } NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) { - return NestedPattern(Instruction::Kind::OperationInst, nested, - [=](const Instruction &inst) { - return isAffineForOp(inst) && filter(inst); - }); + return NestedPattern(nested, [=](const Instruction &inst) { + return isAffineForOp(inst) && filter(inst); + }); } // TODO(ntv): parallel annotation on loops. bool isParallelLoop(const Instruction &inst) { - auto loop = cast<OperationInst>(inst).cast<AffineForOp>(); + auto loop = inst.cast<AffineForOp>(); return loop || true; // loop->isParallel(); }; // TODO(ntv): reduction annotation on loops. bool isReductionLoop(const Instruction &inst) { - auto loop = cast<OperationInst>(inst).cast<AffineForOp>(); + auto loop = inst.cast<AffineForOp>(); return loop || true; // loop->isReduction(); }; bool isLoadOrStore(const Instruction &inst) { - const auto *opInst = dyn_cast<OperationInst>(&inst); - return opInst && (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()); + return inst.isa<LoadOp>() || inst.isa<StoreOp>(); }; } // end namespace matcher diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index 90c5c5fde0e..742c0baa96b 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -35,7 +35,7 @@ struct PrintOpStatsPass : public ModulePass, InstWalker<PrintOpStatsPass> { PassResult runOnModule(Module *m) override; // Updates the operation statistics for the given instruction. - void visitOperationInst(OperationInst *inst); + void visitInstruction(Instruction *inst); // Print summary of op stats. void printSummary(); @@ -58,7 +58,7 @@ PassResult PrintOpStatsPass::runOnModule(Module *m) { return success(); } -void PrintOpStatsPass::visitOperationInst(OperationInst *inst) { +void PrintOpStatsPass::visitInstruction(Instruction *inst) { ++opCount[inst->getName().getStringRef()]; } diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 5211893f055..877a0f2f364 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -53,8 +53,7 @@ void mlir::getForwardSlice(Instruction *inst, return; } - auto *opInst = cast<OperationInst>(inst); - if (auto forOp = opInst->dyn_cast<AffineForOp>()) { + if (auto forOp = inst->dyn_cast<AffineForOp>()) { for (auto &u : forOp->getInductionVar()->getUses()) { auto *ownerInst = u.getOwner(); if (forwardSlice->count(ownerInst) == 0) { @@ -63,9 +62,9 @@ void mlir::getForwardSlice(Instruction *inst, } } } else { - assert(opInst->getNumResults() <= 1 && "NYI: multiple results"); - if (opInst->getNumResults() > 0) { - for (auto &u : opInst->getResult(0)->getUses()) { + assert(inst->getNumResults() <= 1 && "NYI: multiple results"); + if (inst->getNumResults() > 0) { + for (auto &u : inst->getResult(0)->getUses()) { auto *ownerInst = u.getOwner(); if (forwardSlice->count(ownerInst) == 0) { getForwardSlice(ownerInst, forwardSlice, filter, @@ -156,10 +155,9 @@ struct DFSState { } // namespace static void DFSPostorder(Instruction *current, DFSState *state) { - auto *opInst = cast<OperationInst>(current); - assert(opInst->getNumResults() <= 1 && "NYI: multi-result"); - if (opInst->getNumResults() > 0) { - for (auto &u : opInst->getResult(0)->getUses()) { + assert(current->getNumResults() <= 1 && "NYI: multi-result"); + if (current->getNumResults() > 0) { + for (auto &u : current->getResult(0)->getUses()) { auto *inst = u.getOwner(); DFSPostorder(inst, state); } diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 4b8afd9a620..24361ac621f 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -43,10 +43,8 @@ void mlir::getLoopIVs(const Instruction &inst, OpPointer<AffineForOp> currAffineForOp; // Traverse up the hierarchy collecing all 'for' instruction while skipping // over 'if' instructions. - while (currInst && - ((currAffineForOp = - cast<OperationInst>(currInst)->dyn_cast<AffineForOp>()) || - cast<OperationInst>(currInst)->isa<AffineIfOp>())) { + while (currInst && ((currAffineForOp = currInst->dyn_cast<AffineForOp>()) || + currInst->isa<AffineIfOp>())) { if (currAffineForOp) loops->push_back(currAffineForOp); currInst = currInst->getParentInst(); @@ -124,7 +122,7 @@ bool MemRefRegion::unionBoundingBox(const MemRefRegion &other) { // // TODO(bondhugula): extend this to any other memref dereferencing ops // (dma_start, dma_wait). -bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, +bool mlir::getMemRefRegion(Instruction *opInst, unsigned loopDepth, MemRefRegion *region) { unsigned rank; SmallVector<Value *, 4> indices; @@ -279,7 +277,7 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp, std::is_same<LoadOrStoreOpPointer, OpPointer<StoreOp>>::value, "argument should be either a LoadOp or a StoreOp"); - OperationInst *opInst = loadOrStoreOp->getInstruction(); + Instruction *opInst = loadOrStoreOp->getInstruction(); MemRefRegion region; if (!getMemRefRegion(opInst, /*loopDepth=*/0, ®ion)) return false; @@ -359,12 +357,11 @@ static Instruction *getInstAtPosition(ArrayRef<unsigned> positions, } if (level == positions.size() - 1) return &inst; - if (auto childAffineForOp = - cast<OperationInst>(inst).dyn_cast<AffineForOp>()) + if (auto childAffineForOp = inst.dyn_cast<AffineForOp>()) return getInstAtPosition(positions, level + 1, childAffineForOp->getBody()); - for (auto &blockList : cast<OperationInst>(&inst)->getBlockLists()) { + for (auto &blockList : inst.getBlockLists()) { for (auto &b : blockList) if (auto *ret = getInstAtPosition(positions, level + 1, &b)) return ret; @@ -442,7 +439,7 @@ bool mlir::getBackwardComputationSliceState(const MemRefAccess &srcAccess, // TODO(andydavis) Remove dependence on 'srcLoopDepth' here. Instead project // out loop IVs we don't care about and produce smaller slice. OpPointer<AffineForOp> mlir::insertBackwardComputationSlice( - OperationInst *srcOpInst, OperationInst *dstOpInst, unsigned dstLoopDepth, + Instruction *srcOpInst, Instruction *dstOpInst, unsigned dstLoopDepth, ComputationSliceState *sliceState) { // Get loop nest surrounding src operation. SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs; @@ -469,8 +466,7 @@ OpPointer<AffineForOp> mlir::insertBackwardComputationSlice( auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1]; FuncBuilder b(dstAffineForOp->getBody(), dstAffineForOp->getBody()->begin()); auto sliceLoopNest = - cast<OperationInst>(b.clone(*srcLoopIVs[0]->getInstruction())) - ->cast<AffineForOp>(); + b.clone(*srcLoopIVs[0]->getInstruction())->cast<AffineForOp>(); Instruction *sliceInst = getInstAtPosition(positions, /*level=*/0, sliceLoopNest->getBody()); @@ -499,7 +495,7 @@ OpPointer<AffineForOp> mlir::insertBackwardComputationSlice( // Constructs MemRefAccess populating it with the memref, its indices and // opinst from 'loadOrStoreOpInst'. -MemRefAccess::MemRefAccess(OperationInst *loadOrStoreOpInst) { +MemRefAccess::MemRefAccess(Instruction *loadOrStoreOpInst) { if (auto loadOp = loadOrStoreOpInst->dyn_cast<LoadOp>()) { memref = loadOp->getMemRef(); opInst = loadOrStoreOpInst; @@ -527,7 +523,7 @@ unsigned mlir::getNestingDepth(const Instruction &stmt) { const Instruction *currInst = &stmt; unsigned depth = 0; while ((currInst = currInst->getParentInst())) { - if (cast<OperationInst>(currInst)->isa<AffineForOp>()) + if (currInst->isa<AffineForOp>()) depth++; } return depth; @@ -585,7 +581,7 @@ mlir::getMemoryFootprintBytes(ConstOpPointer<AffineForOp> forOp, // Walk this 'for' instruction to gather all memory regions. bool error = false; - const_cast<AffineForOp &>(*forOp).walkOps([&](OperationInst *opInst) { + const_cast<AffineForOp &>(*forOp).walkOps([&](Instruction *opInst) { if (!opInst->isa<LoadOp>() && !opInst->isa<StoreOp>()) { // Neither load nor a store op. return; diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 7cafc81fc0e..9985107008a 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -106,7 +106,7 @@ Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(VectorType superVectorType, /// header file. static AffineMap makePermutationMap( MLIRContext *context, - llvm::iterator_range<OperationInst::operand_iterator> indices, + llvm::iterator_range<Instruction::operand_iterator> indices, const DenseMap<Instruction *, unsigned> &enclosingLoopToVectorDim) { using functional::makePtrDynCaster; using functional::map; @@ -116,8 +116,7 @@ static AffineMap makePermutationMap( for (auto kvp : enclosingLoopToVectorDim) { assert(kvp.second < perm.size()); auto invariants = getInvariantAccesses( - *cast<OperationInst>(kvp.first)->cast<AffineForOp>()->getInductionVar(), - unwrappedIndices); + *kvp.first->cast<AffineForOp>()->getInductionVar(), unwrappedIndices); unsigned numIndices = unwrappedIndices.size(); unsigned countInvariantIndices = 0; for (unsigned dim = 0; dim < numIndices; ++dim) { @@ -142,14 +141,13 @@ static AffineMap makePermutationMap( /// TODO(ntv): could also be implemented as a collect parents followed by a /// filter and made available outside this file. template <typename T> -static SetVector<OperationInst *> getParentsOfType(Instruction *inst) { - SetVector<OperationInst *> res; +static SetVector<Instruction *> getParentsOfType(Instruction *inst) { + SetVector<Instruction *> res; auto *current = inst; while (auto *parent = current->getParentInst()) { - if (auto typedParent = - cast<OperationInst>(parent)->template dyn_cast<T>()) { - assert(res.count(cast<OperationInst>(parent)) == 0 && "Already inserted"); - res.insert(cast<OperationInst>(parent)); + if (auto typedParent = parent->template dyn_cast<T>()) { + assert(res.count(parent) == 0 && "Already inserted"); + res.insert(parent); } current = parent; } @@ -157,12 +155,12 @@ static SetVector<OperationInst *> getParentsOfType(Instruction *inst) { } /// Returns the enclosing AffineForOp, from closest to farthest. -static SetVector<OperationInst *> getEnclosingforOps(Instruction *inst) { +static SetVector<Instruction *> getEnclosingforOps(Instruction *inst) { return getParentsOfType<AffineForOp>(inst); } AffineMap mlir::makePermutationMap( - OperationInst *opInst, + Instruction *opInst, const DenseMap<Instruction *, unsigned> &loopToVectorDim) { DenseMap<Instruction *, unsigned> enclosingLoopToVectorDim; auto enclosingLoops = getEnclosingforOps(opInst); @@ -183,7 +181,7 @@ AffineMap mlir::makePermutationMap( enclosingLoopToVectorDim); } -bool mlir::matcher::operatesOnSuperVectors(const OperationInst &opInst, +bool mlir::matcher::operatesOnSuperVectors(const Instruction &opInst, VectorType subVectorType) { // First, extract the vector type and ditinguish between: // a. ops that *must* lower a super-vector (i.e. vector_transfer_read, diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 16264ad9515..390cbf69c77 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -49,7 +49,7 @@ namespace { /// class FuncVerifier { public: - bool failure(const Twine &message, const OperationInst &value) { + bool failure(const Twine &message, const Instruction &value) { return value.emitError(message); } @@ -60,18 +60,17 @@ public: bool failure(const Twine &message, const Block &bb) { // Take the location information for the first instruction in the block. if (!bb.empty()) - if (auto *op = dyn_cast<OperationInst>(&bb.front())) - return failure(message, *op); + return failure(message, bb.front()); // Worst case, fall back to using the function's location. return failure(message, fn); } - bool verifyAttribute(Attribute attr, const OperationInst &op); + bool verifyAttribute(Attribute attr, const Instruction &op); bool verify(); bool verifyBlock(const Block &block, bool isTopLevel); - bool verifyOperation(const OperationInst &op); + bool verifyOperation(const Instruction &op); bool verifyDominance(const Block &block); bool verifyInstDominance(const Instruction &inst); @@ -135,7 +134,7 @@ bool FuncVerifier::verify() { } // Check that function attributes are all well formed. -bool FuncVerifier::verifyAttribute(Attribute attr, const OperationInst &op) { +bool FuncVerifier::verifyAttribute(Attribute attr, const Instruction &op) { if (!attr.isOrContainsFunction()) return false; @@ -168,14 +167,9 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) { return failure("block argument not owned by block", block); } - for (auto &inst : block) { - switch (inst.getKind()) { - case Instruction::Kind::OperationInst: - if (verifyOperation(cast<OperationInst>(inst))) - return true; - break; - } - } + for (auto &inst : block) + if (verifyOperation(inst)) + return true; // If this block is at the function level, then verify that it has a // terminator. @@ -199,7 +193,7 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) { } /// Check the invariants of the specified operation. -bool FuncVerifier::verifyOperation(const OperationInst &op) { +bool FuncVerifier::verifyOperation(const Instruction &op) { if (op.getFunction() != &fn) return failure("operation in the wrong function", op); @@ -240,19 +234,12 @@ bool FuncVerifier::verifyDominance(const Block &block) { // Check that all operands on the instruction are ok. if (verifyInstDominance(inst)) return true; - - switch (inst.getKind()) { - case Instruction::Kind::OperationInst: { - auto &opInst = cast<OperationInst>(inst); - if (verifyOperation(opInst)) - return true; - for (auto &blockList : opInst.getBlockLists()) - for (auto &block : blockList) - if (verifyDominance(block)) - return true; - break; - } - } + if (verifyOperation(inst)) + return true; + for (auto &blockList : inst.getBlockLists()) + for (auto &block : blockList) + if (verifyDominance(block)) + return true; } return false; } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 2ffdb19ea63..36a4b8e3b5e 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -132,7 +132,6 @@ private: // Visit functions. void visitInstruction(const Instruction *inst); - void visitOperationInst(const OperationInst *opInst); void visitType(Type type); void visitAttribute(Attribute attr); @@ -183,25 +182,18 @@ void ModuleState::visitAttribute(Attribute attr) { } } -void ModuleState::visitOperationInst(const OperationInst *op) { +void ModuleState::visitInstruction(const Instruction *inst) { // Visit all the types used in the operation. - for (auto *operand : op->getOperands()) + for (auto *operand : inst->getOperands()) visitType(operand->getType()); - for (auto *result : op->getResults()) + for (auto *result : inst->getResults()) visitType(result->getType()); // Visit each of the attributes. - for (auto elt : op->getAttrs()) + for (auto elt : inst->getAttrs()) visitAttribute(elt.second); } -void ModuleState::visitInstruction(const Instruction *inst) { - switch (inst->getKind()) { - case Instruction::Kind::OperationInst: - return visitOperationInst(cast<OperationInst>(inst)); - } -} - // Utility to generate a function to register a symbol alias. template <typename SymbolsInModuleSetTy, typename SymbolTy> static void registerSymbolAlias(StringRef name, SymbolTy sym, @@ -1045,8 +1037,8 @@ public: void print(const Instruction *inst); void print(const Block *block, bool printBlockArgs = true); - void printOperation(const OperationInst *op); - void printGenericOp(const OperationInst *op); + void printOperation(const Instruction *op); + void printGenericOp(const Instruction *op); // Implement OpAsmPrinter. raw_ostream &getStream() const { return os; } @@ -1086,7 +1078,7 @@ public: return it != blockIDs.end() ? it->second : ~0U; } - void printSuccessorAndUseList(const OperationInst *term, + void printSuccessorAndUseList(const Instruction *term, unsigned index) override; /// Print a block list. @@ -1162,17 +1154,11 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { for (auto &inst : block) { // We number instruction that have results, and we only number the first // result. - switch (inst.getKind()) { - case Instruction::Kind::OperationInst: { - auto *opInst = cast<OperationInst>(&inst); - if (opInst->getNumResults() != 0) - numberValueID(opInst->getResult(0)); - for (auto &blockList : opInst->getBlockLists()) - for (const auto &block : blockList) - numberValuesInBlock(block); - break; - } - } + if (inst.getNumResults() != 0) + numberValueID(inst.getResult(0)); + for (auto &blockList : inst.getBlockLists()) + for (const auto &block : blockList) + numberValuesInBlock(block); } } @@ -1408,7 +1394,7 @@ void FunctionPrinter::printValueID(const Value *value, os << '#' << resultNo; } -void FunctionPrinter::printOperation(const OperationInst *op) { +void FunctionPrinter::printOperation(const Instruction *op) { if (op->getNumResults()) { printValueID(op->getResult(0), /*printResultNo=*/false); os << " = "; @@ -1425,7 +1411,7 @@ void FunctionPrinter::printOperation(const OperationInst *op) { printGenericOp(op); } -void FunctionPrinter::printGenericOp(const OperationInst *op) { +void FunctionPrinter::printGenericOp(const Instruction *op) { os << '"'; printEscapedString(op->getName().getStringRef(), os); os << "\"("; @@ -1478,7 +1464,7 @@ void FunctionPrinter::printGenericOp(const OperationInst *op) { printBlockList(blockList, /*printEntryBlockArgs=*/true); } -void FunctionPrinter::printSuccessorAndUseList(const OperationInst *term, +void FunctionPrinter::printSuccessorAndUseList(const Instruction *term, unsigned index) { printBlockName(term->getSuccessor(index)); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 94df36c94f6..a756e90a759 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -95,19 +95,19 @@ void OpState::emitNote(const Twine &message) const { // Op Trait implementations //===----------------------------------------------------------------------===// -bool OpTrait::impl::verifyZeroOperands(const OperationInst *op) { +bool OpTrait::impl::verifyZeroOperands(const Instruction *op) { if (op->getNumOperands() != 0) return op->emitOpError("requires zero operands"); return false; } -bool OpTrait::impl::verifyOneOperand(const OperationInst *op) { +bool OpTrait::impl::verifyOneOperand(const Instruction *op) { if (op->getNumOperands() != 1) return op->emitOpError("requires a single operand"); return false; } -bool OpTrait::impl::verifyNOperands(const OperationInst *op, +bool OpTrait::impl::verifyNOperands(const Instruction *op, unsigned numOperands) { if (op->getNumOperands() != numOperands) { return op->emitOpError("expected " + Twine(numOperands) + @@ -117,7 +117,7 @@ bool OpTrait::impl::verifyNOperands(const OperationInst *op, return false; } -bool OpTrait::impl::verifyAtLeastNOperands(const OperationInst *op, +bool OpTrait::impl::verifyAtLeastNOperands(const Instruction *op, unsigned numOperands) { if (op->getNumOperands() < numOperands) return op->emitOpError("expected " + Twine(numOperands) + @@ -137,7 +137,7 @@ static Type getTensorOrVectorElementType(Type type) { return type; } -bool OpTrait::impl::verifyOperandsAreIntegerLike(const OperationInst *op) { +bool OpTrait::impl::verifyOperandsAreIntegerLike(const Instruction *op) { for (auto *operand : op->getOperands()) { auto type = getTensorOrVectorElementType(operand->getType()); if (!type.isIntOrIndex()) @@ -146,7 +146,7 @@ bool OpTrait::impl::verifyOperandsAreIntegerLike(const OperationInst *op) { return false; } -bool OpTrait::impl::verifySameTypeOperands(const OperationInst *op) { +bool OpTrait::impl::verifySameTypeOperands(const Instruction *op) { // Zero or one operand always have the "same" type. unsigned nOperands = op->getNumOperands(); if (nOperands < 2) @@ -160,26 +160,26 @@ bool OpTrait::impl::verifySameTypeOperands(const OperationInst *op) { return false; } -bool OpTrait::impl::verifyZeroResult(const OperationInst *op) { +bool OpTrait::impl::verifyZeroResult(const Instruction *op) { if (op->getNumResults() != 0) return op->emitOpError("requires zero results"); return false; } -bool OpTrait::impl::verifyOneResult(const OperationInst *op) { +bool OpTrait::impl::verifyOneResult(const Instruction *op) { if (op->getNumResults() != 1) return op->emitOpError("requires one result"); return false; } -bool OpTrait::impl::verifyNResults(const OperationInst *op, +bool OpTrait::impl::verifyNResults(const Instruction *op, unsigned numOperands) { if (op->getNumResults() != numOperands) return op->emitOpError("expected " + Twine(numOperands) + " results"); return false; } -bool OpTrait::impl::verifyAtLeastNResults(const OperationInst *op, +bool OpTrait::impl::verifyAtLeastNResults(const Instruction *op, unsigned numOperands) { if (op->getNumResults() < numOperands) return op->emitOpError("expected " + Twine(numOperands) + @@ -208,7 +208,7 @@ static bool verifyShapeMatch(Type type1, Type type2) { return false; } -bool OpTrait::impl::verifySameOperandsAndResultShape(const OperationInst *op) { +bool OpTrait::impl::verifySameOperandsAndResultShape(const Instruction *op) { if (op->getNumOperands() == 0 || op->getNumResults() == 0) return true; @@ -226,7 +226,7 @@ bool OpTrait::impl::verifySameOperandsAndResultShape(const OperationInst *op) { return false; } -bool OpTrait::impl::verifySameOperandsAndResultType(const OperationInst *op) { +bool OpTrait::impl::verifySameOperandsAndResultType(const Instruction *op) { if (op->getNumOperands() == 0 || op->getNumResults() == 0) return true; @@ -245,8 +245,8 @@ bool OpTrait::impl::verifySameOperandsAndResultType(const OperationInst *op) { } static bool verifyBBArguments( - llvm::iterator_range<OperationInst::const_operand_iterator> operands, - const Block *destBB, const OperationInst *op) { + llvm::iterator_range<Instruction::const_operand_iterator> operands, + const Block *destBB, const Instruction *op) { unsigned operandCount = std::distance(operands.begin(), operands.end()); if (operandCount != destBB->getNumArguments()) return op->emitError("branch has " + Twine(operandCount) + @@ -262,7 +262,7 @@ static bool verifyBBArguments( return false; } -static bool verifyTerminatorSuccessors(const OperationInst *op) { +static bool verifyTerminatorSuccessors(const Instruction *op) { // Verify that the operands lines up with the BB arguments in the successor. const Function *fn = op->getFunction(); for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { @@ -275,7 +275,7 @@ static bool verifyTerminatorSuccessors(const OperationInst *op) { return false; } -bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) { +bool OpTrait::impl::verifyIsTerminator(const Instruction *op) { const Block *block = op->getBlock(); // Verify that the operation is at the end of the respective parent block. if (!block || &block->back() != op) @@ -291,7 +291,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) { return false; } -bool OpTrait::impl::verifyResultsAreBoolLike(const OperationInst *op) { +bool OpTrait::impl::verifyResultsAreBoolLike(const Instruction *op) { for (auto *result : op->getResults()) { auto elementType = getTensorOrVectorElementType(result->getType()); bool isBoolType = elementType.isInteger(1); @@ -302,7 +302,7 @@ bool OpTrait::impl::verifyResultsAreBoolLike(const OperationInst *op) { return false; } -bool OpTrait::impl::verifyResultsAreFloatLike(const OperationInst *op) { +bool OpTrait::impl::verifyResultsAreFloatLike(const Instruction *op) { for (auto *result : op->getResults()) { if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>()) return op->emitOpError("requires a floating point type"); @@ -311,7 +311,7 @@ bool OpTrait::impl::verifyResultsAreFloatLike(const OperationInst *op) { return false; } -bool OpTrait::impl::verifyResultsAreIntegerLike(const OperationInst *op) { +bool OpTrait::impl::verifyResultsAreIntegerLike(const Instruction *op) { for (auto *result : op->getResults()) { auto type = getTensorOrVectorElementType(result->getType()); if (!type.isIntOrIndex()) @@ -344,7 +344,7 @@ bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) { parser->addTypeToList(type, result->types); } -void impl::printBinaryOp(const OperationInst *op, OpAsmPrinter *p) { +void impl::printBinaryOp(const Instruction *op, OpAsmPrinter *p) { assert(op->getNumOperands() == 2 && "binary op should have two operands"); assert(op->getNumResults() == 1 && "binary op should have one result"); @@ -383,7 +383,7 @@ bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) { parser->addTypeToList(dstType, result->types); } -void impl::printCastOp(const OperationInst *op, OpAsmPrinter *p) { +void impl::printCastOp(const Instruction *op, OpAsmPrinter *p) { *p << op->getName() << ' ' << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType(); } diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 94873c5e330..fd6957a9fb3 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -22,7 +22,7 @@ using namespace mlir; /// If this value is the result of an Instruction, return the instruction /// that defines it. -OperationInst *Value::getDefiningInst() { +Instruction *Value::getDefiningInst() { if (auto *result = dyn_cast<InstResult>(this)) return result->getOwner(); return nullptr; diff --git a/mlir/test/mlir-tblgen/one-op-one-result.td b/mlir/test/mlir-tblgen/one-op-one-result.td index 3bdd9aa1b96..25023c89168 100644 --- a/mlir/test/mlir-tblgen/one-op-one-result.td +++ b/mlir/test/mlir-tblgen/one-op-one-result.td @@ -21,8 +21,8 @@ def : Pat<(X_AddOp $lhs, $rhs), (Y_AddOp $lhs, U:$rhs, T_Const_Attr:$x)>; // CHECK: struct GeneratedConvert0 : public RewritePattern // CHECK: RewritePattern("x.add", 1, context) -// CHECK: PatternMatchResult match(OperationInst * -// CHECK: void rewrite(OperationInst *op, std::unique_ptr<PatternState> +// CHECK: PatternMatchResult match(Instruction * +// CHECK: void rewrite(Instruction *op, std::unique_ptr<PatternState> // CHECK: PatternRewriter &rewriter) // CHECK: rewriter.replaceOpWithNewOp<AddOp>(op, op->getResult(0)->getType() // CHECK: void populateWithGenerated diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 30e5989939f..5264a9e15f5 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -169,7 +169,7 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) { os << "private:\n friend class ::mlir::Instruction;\n" << " explicit " << emitter.op.getCppClassName() - << "(const OperationInst* state) : Op(state) {}\n};\n"; + << "(const Instruction* state) : Op(state) {}\n};\n"; emitter.mapOverClassNamespaces( [&os](StringRef ns) { os << "} // end namespace " << ns << "\n"; }); } diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 7ca663071d8..1e6eecb3ed7 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -222,7 +222,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, void PatternEmitter::emitMatchMethod(DagNode tree) { // Emit the heading. os << R"( - PatternMatchResult match(OperationInst *op0) const override { + PatternMatchResult match(Instruction *op0) const override { // TODO: This just handle 1 result if (op0->getNumResults() != 1) return matchFailure(); auto ctx = op0->getContext(); (void)ctx; @@ -280,7 +280,7 @@ void PatternEmitter::emitRewriteMethod() { PrintFatalError(loc, "only single op result supported"); os << R"( - void rewrite(OperationInst *op, std::unique_ptr<PatternState> state, + void rewrite(Instruction *op, std::unique_ptr<PatternState> state, PatternRewriter &rewriter) const override { auto& s = *static_cast<MatchedState *>(state.get()); )"; |

