diff options
| author | River Riddle <riverriddle@google.com> | 2019-02-04 10:38:47 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 16:10:53 -0700 |
| commit | b499277fb648c44907443ce44ec6bcc6b7596039 (patch) | |
| tree | 7d61826527e189c2952cbe5b2498b2c43ab6f839 /mlir | |
| parent | 44e040dd635a8ce4b362cc81213f5e791b20830e (diff) | |
| download | bcm5719-llvm-b499277fb648c44907443ce44ec6bcc6b7596039.tar.gz bcm5719-llvm-b499277fb648c44907443ce44ec6bcc6b7596039.zip | |
Remove remaining usages of OperationInst in lib/Transforms.
PiperOrigin-RevId: 232323671
Diffstat (limited to 'mlir')
20 files changed, 251 insertions, 317 deletions
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index e471b6792c5..63a676d7b52 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -39,10 +39,10 @@ using namespace mlir; namespace { // TODO(riverriddle) Handle commutative operations. -struct SimpleOperationInfo : public llvm::DenseMapInfo<OperationInst *> { - static unsigned getHashValue(const OperationInst *op) { +struct SimpleOperationInfo : public llvm::DenseMapInfo<Instruction *> { + static unsigned getHashValue(const Instruction *op) { // Hash the operations based upon their: - // - OperationInst Name + // - Instruction Name // - Attributes // - Result Types // - Operands @@ -51,7 +51,7 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<OperationInst *> { hash_combine_range(op->result_type_begin(), op->result_type_end()), hash_combine_range(op->operand_begin(), op->operand_end())); } - static bool isEqual(const OperationInst *lhs, const OperationInst *rhs) { + static bool isEqual(const Instruction *lhs, const Instruction *rhs) { if (lhs == rhs) return true; if (lhs == getTombstoneKey() || lhs == getEmptyKey() || @@ -89,8 +89,8 @@ struct CSE : public FunctionPass { /// Shared implementation of operation elimination and scoped map definitions. using AllocatorTy = llvm::RecyclingAllocator< llvm::BumpPtrAllocator, - llvm::ScopedHashTableVal<OperationInst *, OperationInst *>>; - using ScopedMapTy = llvm::ScopedHashTable<OperationInst *, OperationInst *, + llvm::ScopedHashTableVal<Instruction *, Instruction *>>; + using ScopedMapTy = llvm::ScopedHashTable<Instruction *, Instruction *, SimpleOperationInfo, AllocatorTy>; /// Represents a single entry in the depth first traversal of a CFG. @@ -111,7 +111,7 @@ struct CSE : public FunctionPass { /// Attempt to eliminate a redundant operation. Returns true if the operation /// was marked for removal, false otherwise. - bool simplifyOperation(OperationInst *op); + bool simplifyOperation(Instruction *op); void simplifyBlock(Block *bb); @@ -122,14 +122,14 @@ private: ScopedMapTy knownValues; /// Operations marked as dead and to be erased. - std::vector<OperationInst *> opsToErase; + std::vector<Instruction *> opsToErase; }; } // end anonymous namespace char CSE::passID = 0; /// Attempt to eliminate a redundant operation. -bool CSE::simplifyOperation(OperationInst *op) { +bool CSE::simplifyOperation(Instruction *op) { // TODO(riverriddle) We currently only eliminate non side-effecting // operations. if (!op->hasNoSideEffect()) @@ -166,23 +166,16 @@ bool CSE::simplifyOperation(OperationInst *op) { void CSE::simplifyBlock(Block *bb) { for (auto &i : *bb) { - switch (i.getKind()) { - case Instruction::Kind::OperationInst: { - auto *opInst = cast<OperationInst>(&i); - - // If the operation is simplified, we don't process any held block lists. - if (simplifyOperation(opInst)) - continue; - - // Simplify any held blocks. - for (auto &blockList : opInst->getBlockLists()) { - for (auto &b : blockList) { - ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(&b); - } + // If the operation is simplified, we don't process any held block lists. + if (simplifyOperation(&i)) + continue; + + // Simplify any held blocks. + for (auto &blockList : i.getBlockLists()) { + for (auto &b : blockList) { + ScopedMapTy::ScopeTy scope(knownValues); + simplifyBlock(&b); } - break; - } } } } diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp index 4f960ea73af..4a6430dc9be 100644 --- a/mlir/lib/Transforms/ComposeAffineMaps.cpp +++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp @@ -48,7 +48,7 @@ namespace { struct ComposeAffineMaps : public FunctionPass, InstWalker<ComposeAffineMaps> { explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {} PassResult runOnFunction(Function *f) override; - void visitInstruction(OperationInst *opInst); + void visitInstruction(Instruction *opInst); SmallVector<OpPointer<AffineApplyOp>, 8> affineApplyOps; @@ -64,14 +64,12 @@ FunctionPass *mlir::createComposeAffineMapsPass() { } static bool affineApplyOp(const Instruction &inst) { - const auto &opInst = cast<OperationInst>(inst); - return opInst.isa<AffineApplyOp>(); + return inst.isa<AffineApplyOp>(); } -void ComposeAffineMaps::visitInstruction(OperationInst *opInst) { - if (auto afOp = opInst->dyn_cast<AffineApplyOp>()) { +void ComposeAffineMaps::visitInstruction(Instruction *opInst) { + if (auto afOp = opInst->dyn_cast<AffineApplyOp>()) affineApplyOps.push_back(afOp); - } } PassResult ComposeAffineMaps::runOnFunction(Function *f) { diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 859d0012fac..54486cdb293 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -33,11 +33,11 @@ struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> { // All constants in the function post folding. SmallVector<Value *, 8> existingConstants; // Operations that were folded and that need to be erased. - std::vector<OperationInst *> opInstsToErase; + std::vector<Instruction *> opInstsToErase; - bool foldOperation(OperationInst *op, + bool foldOperation(Instruction *op, SmallVectorImpl<Value *> &existingConstants); - void visitInstruction(OperationInst *op); + void visitInstruction(Instruction *op); PassResult runOnFunction(Function *f) override; static char passID; @@ -49,7 +49,7 @@ char ConstantFold::passID = 0; /// Attempt to fold the specified operation, updating the IR to match. If /// constants are found, we keep track of them in the existingConstants list. /// -void ConstantFold::visitInstruction(OperationInst *op) { +void ConstantFold::visitInstruction(Instruction *op) { // If this operation is an AffineForOp, then fold the bounds. if (auto forOp = op->dyn_cast<AffineForOp>()) { constantFoldBounds(forOp); diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 443e7750947..996416d9271 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -50,7 +50,7 @@ private: // Utility that looks up a list of value in the value remapping table. Returns // an empty vector if one of the values is not mapped yet. SmallVector<Value *, 4> - lookupValues(const llvm::iterator_range<OperationInst::const_operand_iterator> + lookupValues(const llvm::iterator_range<Instruction::const_operand_iterator> &operands); // Converts the given function to the dialect using hooks defined in @@ -61,13 +61,13 @@ private: // from `valueRemapping` and the converted blocks from `blockRemapping`, and // passes them to `converter->rewriteTerminator` function defined in the // pattern, together with `builder`. - bool convertOpWithSuccessors(DialectOpConversion *converter, - OperationInst *op, FuncBuilder &builder); + bool convertOpWithSuccessors(DialectOpConversion *converter, Instruction *op, + FuncBuilder &builder); // Converts an operation without successors. Extracts the converted operands // from `valueRemapping` and passes them to the `converter->rewrite` function // defined in the pattern, together with `builder`. - bool convertOp(DialectOpConversion *converter, OperationInst *op, + bool convertOp(DialectOpConversion *converter, Instruction *op, FuncBuilder &builder); // Converts a block by traversing its instructions sequentially, looking for @@ -104,8 +104,7 @@ private: } // end namespace mlir SmallVector<Value *, 4> impl::FunctionConversion::lookupValues( - const llvm::iterator_range<OperationInst::const_operand_iterator> - &operands) { + const llvm::iterator_range<Instruction::const_operand_iterator> &operands) { SmallVector<Value *, 4> remapped; remapped.reserve(llvm::size(operands)); for (const Value *operand : operands) { @@ -118,7 +117,7 @@ SmallVector<Value *, 4> impl::FunctionConversion::lookupValues( } bool impl::FunctionConversion::convertOpWithSuccessors( - DialectOpConversion *converter, OperationInst *op, FuncBuilder &builder) { + DialectOpConversion *converter, Instruction *op, FuncBuilder &builder) { SmallVector<Block *, 2> destinations; destinations.reserve(op->getNumSuccessors()); SmallVector<Value *, 4> operands = lookupValues(op->getOperands()); @@ -149,7 +148,7 @@ bool impl::FunctionConversion::convertOpWithSuccessors( } bool impl::FunctionConversion::convertOp(DialectOpConversion *converter, - OperationInst *op, + Instruction *op, FuncBuilder &builder) { auto operands = lookupValues(op->getOperands()); assert((!operands.empty() || op->getNumOperands() == 0) && @@ -174,24 +173,22 @@ bool impl::FunctionConversion::convertBlock( // Iterate over ops and convert them. for (Instruction &inst : *block) { - auto op = dyn_cast<OperationInst>(&inst); - if (!op) { - inst.emitError("unsupported instruction (For/If)"); + if (inst.getNumBlockLists() != 0) { + inst.emitError("unsupported region instruction"); return true; } // Find the first matching conversion and apply it. bool converted = false; for (auto *conversion : conversions) { - if (!conversion->match(op)) + if (!conversion->match(&inst)) continue; - if (op->isTerminator() && op->getNumSuccessors() > 0) { - if (convertOpWithSuccessors(conversion, op, builder)) - return true; - } else { - if (convertOp(conversion, op, builder)) + if (inst.isTerminator() && inst.getNumSuccessors() > 0) { + if (convertOpWithSuccessors(conversion, &inst, builder)) return true; + } else if (convertOp(conversion, &inst, builder)) { + return true; } converted = true; break; diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 2bbb32036c2..92ae3767098 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -157,8 +157,7 @@ static void getMultiLevelStrides(const MemRefRegion ®ion, /// dynamic shaped memref's for now. `numParamLoopIVs` is the number of /// enclosing loop IVs of opInst (starting from the outermost) that the region /// is parametric on. -static bool getFullMemRefAsRegion(OperationInst *opInst, - unsigned numParamLoopIVs, +static bool getFullMemRefAsRegion(Instruction *opInst, unsigned numParamLoopIVs, MemRefRegion *region) { unsigned rank; if (auto loadOp = opInst->dyn_cast<LoadOp>()) { @@ -563,7 +562,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { fastBufferMap.clear(); // Walk this range of instructions to gather all memory regions. - block->walk(begin, end, [&](OperationInst *opInst) { + block->walk(begin, end, [&](Instruction *opInst) { // Gather regions to allocate to buffers in faster memory space. if (auto loadOp = opInst->dyn_cast<LoadOp>()) { if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace) diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 304331320ac..d7d69e569e5 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -114,11 +114,11 @@ namespace { class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> { public: SmallVector<OpPointer<AffineForOp>, 4> forOps; - SmallVector<OperationInst *, 4> loadOpInsts; - SmallVector<OperationInst *, 4> storeOpInsts; + SmallVector<Instruction *, 4> loadOpInsts; + SmallVector<Instruction *, 4> storeOpInsts; bool hasNonForRegion = false; - void visitInstruction(OperationInst *opInst) { + void visitInstruction(Instruction *opInst) { if (opInst->isa<AffineForOp>()) forOps.push_back(opInst->cast<AffineForOp>()); else if (opInst->getNumBlockLists() != 0) @@ -131,7 +131,7 @@ public: }; // TODO(b/117228571) Replace when this is modeled through side-effects/op traits -static bool isMemRefDereferencingOp(const OperationInst &op) { +static bool isMemRefDereferencingOp(const Instruction &op) { if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() || op.isa<DmaWaitOp>()) return true; @@ -153,9 +153,9 @@ public: // The top-level statment which is (or contains) loads/stores. Instruction *inst; // List of load operations. - SmallVector<OperationInst *, 4> loads; + SmallVector<Instruction *, 4> loads; // List of store op insts. - SmallVector<OperationInst *, 4> stores; + SmallVector<Instruction *, 4> stores; Node(unsigned id, Instruction *inst) : id(id), inst(inst) {} // Returns the load op count for 'memref'. @@ -258,16 +258,13 @@ public: for (auto *storeOpInst : node->stores) { auto *memref = storeOpInst->cast<StoreOp>()->getMemRef(); auto *inst = memref->getDefiningInst(); - auto *opInst = dyn_cast_or_null<OperationInst>(inst); - // Return false if 'memref' is a function argument. - if (opInst == nullptr) + // Return false if 'memref' is a block argument. + if (!inst) return true; // Return false if any use of 'memref' escapes the function. - for (auto &use : memref->getUses()) { - auto *user = dyn_cast<OperationInst>(use.getOwner()); - if (!user || !isMemRefDereferencingOp(*user)) + for (auto &use : memref->getUses()) + if (!isMemRefDereferencingOp(*use.getOwner())) return true; - } } return false; } @@ -461,8 +458,8 @@ public: } // Adds ops in 'loads' and 'stores' to node at 'id'. - void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads, - const SmallVectorImpl<OperationInst *> &stores) { + void addToNode(unsigned id, const SmallVectorImpl<Instruction *> &loads, + const SmallVectorImpl<Instruction *> &stores) { Node *node = getNode(id); for (auto *loadOpInst : loads) node->loads.push_back(loadOpInst); @@ -509,7 +506,7 @@ bool MemRefDependenceGraph::init(Function *f) { DenseMap<Instruction *, unsigned> forToNodeMap; for (auto &inst : f->front()) { - if (auto forOp = cast<OperationInst>(&inst)->dyn_cast<AffineForOp>()) { + if (auto forOp = inst.dyn_cast<AffineForOp>()) { // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. LoopNestStateCollector collector; @@ -530,30 +527,28 @@ bool MemRefDependenceGraph::init(Function *f) { } forToNodeMap[&inst] = node.id; nodes.insert({node.id, node}); - } else if (auto *opInst = dyn_cast<OperationInst>(&inst)) { - if (auto loadOp = opInst->dyn_cast<LoadOp>()) { - // Create graph node for top-level load op. - Node node(nextNodeId++, &inst); - node.loads.push_back(opInst); - auto *memref = opInst->cast<LoadOp>()->getMemRef(); - memrefAccesses[memref].insert(node.id); - nodes.insert({node.id, node}); - } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) { - // Create graph node for top-level store op. - Node node(nextNodeId++, &inst); - node.stores.push_back(opInst); - auto *memref = opInst->cast<StoreOp>()->getMemRef(); - memrefAccesses[memref].insert(node.id); - nodes.insert({node.id, node}); - } else if (opInst->getNumBlockLists() != 0) { - // Return false if another region is found (not currently supported). - return false; - } else if (opInst->getNumResults() > 0 && !opInst->use_empty()) { - // Create graph node for top-level producer of SSA values, which - // could be used by loop nest nodes. - Node node(nextNodeId++, &inst); - nodes.insert({node.id, node}); - } + } else if (auto loadOp = inst.dyn_cast<LoadOp>()) { + // Create graph node for top-level load op. + Node node(nextNodeId++, &inst); + node.loads.push_back(&inst); + auto *memref = inst.cast<LoadOp>()->getMemRef(); + memrefAccesses[memref].insert(node.id); + nodes.insert({node.id, node}); + } else if (auto storeOp = inst.dyn_cast<StoreOp>()) { + // Create graph node for top-level store op. + Node node(nextNodeId++, &inst); + node.stores.push_back(&inst); + auto *memref = inst.cast<StoreOp>()->getMemRef(); + memrefAccesses[memref].insert(node.id); + nodes.insert({node.id, node}); + } else if (inst.getNumBlockLists() != 0) { + // Return false if another region is found (not currently supported). + return false; + } else if (inst.getNumResults() > 0 && !inst.use_empty()) { + // Create graph node for top-level producer of SSA values, which + // could be used by loop nest nodes. + Node node(nextNodeId++, &inst); + nodes.insert({node.id, node}); } } @@ -563,12 +558,11 @@ bool MemRefDependenceGraph::init(Function *f) { const Node &node = idAndNode.second; if (!node.loads.empty() || !node.stores.empty()) continue; - auto *opInst = cast<OperationInst>(node.inst); + auto *opInst = node.inst; for (auto *value : opInst->getResults()) { for (auto &use : value->getUses()) { - auto *userOpInst = cast<OperationInst>(use.getOwner()); SmallVector<OpPointer<AffineForOp>, 4> loops; - getLoopIVs(*userOpInst, &loops); + getLoopIVs(*use.getOwner(), &loops); if (loops.empty()) continue; assert(forToNodeMap.count(loops[0]->getInstruction()) > 0); @@ -619,7 +613,7 @@ public: LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {} - void visitInstruction(OperationInst *opInst) { + void visitInstruction(Instruction *opInst) { auto forOp = opInst->dyn_cast<AffineForOp>(); if (!forOp) return; @@ -627,8 +621,7 @@ public: auto *forInst = forOp->getInstruction(); auto *parentInst = forOp->getInstruction()->getParentInst(); if (parentInst != nullptr) { - assert(cast<OperationInst>(parentInst)->isa<AffineForOp>() && - "Expected parent AffineForOp"); + assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp"); // Add mapping to 'forOp' from its parent AffineForOp. stats->loopMap[parentInst].push_back(forOp); } @@ -637,8 +630,7 @@ public: unsigned count = 0; stats->opCountMap[forInst] = 0; for (auto &inst : *forOp->getBody()) { - if (!(cast<OperationInst>(inst).isa<AffineForOp>() || - cast<OperationInst>(inst).isa<AffineIfOp>())) + if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>())) ++count; } stats->opCountMap[forInst] = count; @@ -723,7 +715,7 @@ static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) { // was encountered). // TODO(andydavis) Make this work with non-unit step loops. static bool buildSliceTripCountMap( - OperationInst *srcOpInst, ComputationSliceState *sliceState, + Instruction *srcOpInst, ComputationSliceState *sliceState, llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountMap) { SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); @@ -755,10 +747,10 @@ static bool buildSliceTripCountMap( // adds them to 'dstLoads'. static void moveLoadsAccessingMemrefTo(Value *memref, - SmallVectorImpl<OperationInst *> *srcLoads, - SmallVectorImpl<OperationInst *> *dstLoads) { + SmallVectorImpl<Instruction *> *srcLoads, + SmallVectorImpl<Instruction *> *dstLoads) { dstLoads->clear(); - SmallVector<OperationInst *, 4> srcLoadsToKeep; + SmallVector<Instruction *, 4> srcLoadsToKeep; for (auto *load : *srcLoads) { if (load->cast<LoadOp>()->getMemRef() == memref) dstLoads->push_back(load); @@ -769,7 +761,7 @@ moveLoadsAccessingMemrefTo(Value *memref, } // Returns the innermost common loop depth for the set of operations in 'ops'. -static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) { +static unsigned getInnermostCommonLoopDepth(ArrayRef<Instruction *> ops) { unsigned numOps = ops.size(); assert(numOps > 0); @@ -797,10 +789,10 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) { // Returns the maximum loop depth at which no dependences between 'loadOpInsts' // and 'storeOpInsts' are satisfied. -static unsigned getMaxLoopDepth(ArrayRef<OperationInst *> loadOpInsts, - ArrayRef<OperationInst *> storeOpInsts) { +static unsigned getMaxLoopDepth(ArrayRef<Instruction *> loadOpInsts, + ArrayRef<Instruction *> storeOpInsts) { // Merge loads and stores into the same array. - SmallVector<OperationInst *, 2> ops(loadOpInsts.begin(), loadOpInsts.end()); + SmallVector<Instruction *, 2> ops(loadOpInsts.begin(), loadOpInsts.end()); ops.append(storeOpInsts.begin(), storeOpInsts.end()); // Compute the innermost common loop depth for loads and stores. @@ -913,7 +905,7 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { // TODO(bondhugula): consider refactoring the common code from generateDma and // this one. static Value *createPrivateMemRef(OpPointer<AffineForOp> forOp, - OperationInst *srcStoreOpInst, + Instruction *srcStoreOpInst, unsigned dstLoopDepth, Optional<unsigned> fastMemorySpace, unsigned localBufSizeThreshold) { @@ -1061,9 +1053,9 @@ static uint64_t getSliceIterationCount( // *) Compares the total cost of the unfused loop nests to the min cost fused // loop nest computed in the previous step, and returns true if the latter // is lower. -static bool isFusionProfitable(OperationInst *srcOpInst, - ArrayRef<OperationInst *> dstLoadOpInsts, - ArrayRef<OperationInst *> dstStoreOpInsts, +static bool isFusionProfitable(Instruction *srcOpInst, + ArrayRef<Instruction *> dstLoadOpInsts, + ArrayRef<Instruction *> dstStoreOpInsts, ComputationSliceState *sliceState, unsigned *dstLoopDepth) { LLVM_DEBUG({ @@ -1174,7 +1166,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst, computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1; for (auto *loadOp : dstLoadOpInsts) { auto *parentInst = loadOp->getParentInst(); - if (parentInst && cast<OperationInst>(parentInst)->isa<AffineForOp>()) + if (parentInst && parentInst->isa<AffineForOp>()) computeCostMap[parentInst] = -1; } } @@ -1393,11 +1385,11 @@ public: // Get 'dstNode' into which to attempt fusion. auto *dstNode = mdg->getNode(dstId); // Skip if 'dstNode' is not a loop nest. - if (!cast<OperationInst>(dstNode->inst)->isa<AffineForOp>()) + if (!dstNode->inst->isa<AffineForOp>()) continue; - SmallVector<OperationInst *, 4> loads = dstNode->loads; - SmallVector<OperationInst *, 4> dstLoadOpInsts; + SmallVector<Instruction *, 4> loads = dstNode->loads; + SmallVector<Instruction *, 4> dstLoadOpInsts; DenseSet<Value *> visitedMemrefs; while (!loads.empty()) { // Get memref of load on top of the stack. @@ -1426,7 +1418,7 @@ public: // Get 'srcNode' from which to attempt fusion into 'dstNode'. auto *srcNode = mdg->getNode(srcId); // Skip if 'srcNode' is not a loop nest. - if (!cast<OperationInst>(srcNode->inst)->isa<AffineForOp>()) + if (!srcNode->inst->isa<AffineForOp>()) continue; // Skip if 'srcNode' has more than one store to any memref. // TODO(andydavis) Support fusing multi-output src loop nests. @@ -1454,7 +1446,7 @@ public: // Get unique 'srcNode' store op. auto *srcStoreOpInst = srcNode->stores.front(); // Gather 'dstNode' store ops to 'memref'. - SmallVector<OperationInst *, 2> dstStoreOpInsts; + SmallVector<Instruction *, 2> dstStoreOpInsts; for (auto *storeOpInst : dstNode->stores) if (storeOpInst->cast<StoreOp>()->getMemRef() == memref) dstStoreOpInsts.push_back(storeOpInst); @@ -1472,8 +1464,7 @@ public: srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); if (sliceLoopNest != nullptr) { // Move 'dstAffineForOp' before 'insertPointInst' if needed. - auto dstAffineForOp = - cast<OperationInst>(dstNode->inst)->cast<AffineForOp>(); + auto dstAffineForOp = dstNode->inst->cast<AffineForOp>(); if (insertPointInst != dstAffineForOp->getInstruction()) { dstAffineForOp->getInstruction()->moveBefore(insertPointInst); } @@ -1488,7 +1479,7 @@ public: promoteIfSingleIteration(forOp); } // Create private memref for 'memref' in 'dstAffineForOp'. - SmallVector<OperationInst *, 4> storesForMemref; + SmallVector<Instruction *, 4> storesForMemref; for (auto *storeOpInst : sliceCollector.storeOpInsts) { if (storeOpInst->cast<StoreOp>()->getMemRef() == memref) storesForMemref.push_back(storeOpInst); @@ -1541,9 +1532,8 @@ public: continue; // Use list expected to match the dep graph info. auto *inst = memref->getDefiningInst(); - auto *opInst = dyn_cast_or_null<OperationInst>(inst); - if (opInst && opInst->isa<AllocOp>()) - opInst->erase(); + if (inst && inst->isa<AllocOp>()) + inst->erase(); } } }; diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index f1ee7fd1853..8b368e5f182 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -237,14 +237,13 @@ getTileableBands(Function *f, do { band.push_back(currInst); } while (currInst->getBody()->getInstructions().size() == 1 && - (currInst = cast<OperationInst>(currInst->getBody()->front()) - .dyn_cast<AffineForOp>())); + (currInst = currInst->getBody()->front().dyn_cast<AffineForOp>())); bands->push_back(band); }; for (auto &block : *f) for (auto &inst : block) - if (auto forOp = cast<OperationInst>(inst).dyn_cast<AffineForOp>()) + if (auto forOp = inst.dyn_cast<AffineForOp>()) getMaximalPerfectLoopNest(forOp); } diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 9c9952d31ca..b1e15ccb07b 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -113,7 +113,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return hasInnerLoops; } - bool walkPostOrder(OperationInst *opInst) { + bool walkPostOrder(Instruction *opInst) { bool hasInnerLoops = false; for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) @@ -140,7 +140,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) { const unsigned minTripCount; ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {} - void visitInstruction(OperationInst *opInst) { + void visitInstruction(Instruction *opInst) { auto forOp = opInst->dyn_cast<AffineForOp>(); if (!forOp) return; diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index d87f9d5dc14..74c54fde047 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -100,8 +100,7 @@ PassResult LoopUnrollAndJam::runOnFunction(Function *f) { // any for Inst. auto &entryBlock = f->front(); if (!entryBlock.empty()) - if (auto forOp = - cast<OperationInst>(entryBlock.front()).dyn_cast<AffineForOp>()) + if (auto forOp = entryBlock.front().dyn_cast<AffineForOp>()) runOnAffineForOp(forOp); return success(); @@ -149,12 +148,12 @@ bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp, void walk(InstListType::iterator Start, InstListType::iterator End) { for (auto it = Start; it != End;) { auto subBlockStart = it; - while (it != End && !cast<OperationInst>(it)->isa<AffineForOp>()) + while (it != End && !it->isa<AffineForOp>()) ++it; if (it != subBlockStart) subBlocks.push_back({subBlockStart, std::prev(it)}); // Process all for insts that appear next. - while (it != End && cast<OperationInst>(it)->isa<AffineForOp>()) + while (it != End && it->isa<AffineForOp>()) walk(&*it++); } } @@ -206,8 +205,7 @@ bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp, // Insert the cleanup loop right after 'forOp'. FuncBuilder builder(forInst->getBlock(), std::next(Block::iterator(forInst))); - auto cleanupAffineForOp = - cast<OperationInst>(builder.clone(*forInst))->cast<AffineForOp>(); + auto cleanupAffineForOp = builder.clone(*forInst)->cast<AffineForOp>(); cleanupAffineForOp->setLowerBoundMap( getCleanupLoopLowerBound(forOp, unrollJamFactor, &builder)); diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 08c8188fada..88ccc90c18b 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -616,23 +616,21 @@ PassResult LowerAffinePass::runOnFunction(Function *function) { // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. // We do this as a prepass to avoid invalidating the walker with our rewrite. function->walk([&](Instruction *inst) { - auto op = cast<OperationInst>(inst); - if (op->isa<AffineApplyOp>() || op->isa<AffineForOp>() || - op->isa<AffineIfOp>()) + if (inst->isa<AffineApplyOp>() || inst->isa<AffineForOp>() || + inst->isa<AffineIfOp>()) instsToRewrite.push_back(inst); }); // Rewrite all of the ifs and fors. We walked the instructions in preorder, // so we know that we will rewrite them in the same order. for (auto *inst : instsToRewrite) { - auto op = cast<OperationInst>(inst); - if (auto ifOp = op->dyn_cast<AffineIfOp>()) { + if (auto ifOp = inst->dyn_cast<AffineIfOp>()) { if (lowerAffineIf(ifOp)) return failure(); - } else if (auto forOp = op->dyn_cast<AffineForOp>()) { + } else if (auto forOp = inst->dyn_cast<AffineForOp>()) { if (lowerAffineFor(forOp)) return failure(); - } else if (lowerAffineApply(op->cast<AffineApplyOp>())) { + } else if (lowerAffineApply(inst->cast<AffineApplyOp>())) { return failure(); } } diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 7f1e9b157d8..63fb45db9c5 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -401,13 +401,12 @@ public: explicit VectorTransferExpander(MLIRContext *context) : MLLoweringPattern(VectorTransferOpTy::getOperationName(), 1, context) {} - PatternMatchResult match(OperationInst *op) const override { + PatternMatchResult match(Instruction *op) const override { if (m_Op<VectorTransferOpTy>().match(op)) return matchSuccess(); return matchFailure(); } - void rewriteOpInst(OperationInst *op, - MLFuncGlobalLoweringState *funcWiseState, + void rewriteOpInst(Instruction *op, MLFuncGlobalLoweringState *funcWiseState, std::unique_ptr<PatternState> opState, MLFuncLoweringRewriter *rewriter) const override { VectorTransferRewriter<VectorTransferOpTy>( diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index f2dae11112b..f55c2154f08 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -246,8 +246,8 @@ static SmallVector<unsigned, 8> delinearize(unsigned linearIndex, return res; } -static OperationInst * -instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType, +static Instruction * +instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType, DenseMap<const Value *, Value *> *substitutionsMap); /// Not all Values belong to a program slice scoped within the immediately @@ -391,7 +391,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType, /// - constant splat is replaced by constant splat of `hwVectorType`. /// TODO(ntv): add more substitutions on a per-need basis. static SmallVector<NamedAttribute, 1> -materializeAttributes(OperationInst *opInst, VectorType hwVectorType) { +materializeAttributes(Instruction *opInst, VectorType hwVectorType) { SmallVector<NamedAttribute, 1> res; for (auto a : opInst->getAttrs()) { if (auto splat = a.second.dyn_cast<SplatElementsAttr>()) { @@ -411,8 +411,8 @@ materializeAttributes(OperationInst *opInst, VectorType hwVectorType) { /// substitutionsMap. /// /// If the underlying substitution fails, this fails too and returns nullptr. -static OperationInst * -instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType, +static Instruction * +instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType, DenseMap<const Value *, Value *> *substitutionsMap) { assert(!opInst->isa<VectorTransferReadOp>() && "Should call the function specialized for VectorTransferReadOp"); @@ -488,7 +488,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer, /// `hwVectorType` int the covering of the super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. -static OperationInst * +static Instruction * instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, DenseMap<const Value *, Value *> *substitutionsMap) { @@ -512,7 +512,7 @@ instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType, /// `hwVectorType` int the covering of th3e super-vector type. For a more /// detailed description of the problem, see the description of /// reindexAffineIndices. -static OperationInst * +static Instruction * instantiate(FuncBuilder *b, VectorTransferWriteOp *write, VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, DenseMap<const Value *, Value *> *substitutionsMap) { @@ -555,21 +555,20 @@ static bool instantiateMaterialization(Instruction *inst, // Create a builder here for unroll-and-jam effects. FuncBuilder b(inst); - auto *opInst = cast<OperationInst>(inst); // AffineApplyOp are ignored: instantiating the proper vector op will take // care of AffineApplyOps by composing them properly. - if (opInst->isa<AffineApplyOp>()) { + if (inst->isa<AffineApplyOp>()) { return false; } - if (opInst->getNumBlockLists() != 0) + if (inst->getNumBlockLists() != 0) return inst->emitError("NYI path Op with region"); - if (auto write = opInst->dyn_cast<VectorTransferWriteOp>()) { + if (auto write = inst->dyn_cast<VectorTransferWriteOp>()) { auto *clone = instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); return clone == nullptr; } - if (auto read = opInst->dyn_cast<VectorTransferReadOp>()) { + if (auto read = inst->dyn_cast<VectorTransferReadOp>()) { auto *clone = instantiate(&b, read, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); if (!clone) { @@ -582,19 +581,19 @@ static bool instantiateMaterialization(Instruction *inst, // The only op with 0 results reaching this point must, by construction, be // VectorTransferWriteOps and have been caught above. Ops with >= 2 results // are not yet supported. So just support 1 result. - if (opInst->getNumResults() != 1) { + if (inst->getNumResults() != 1) { return inst->emitError("NYI: ops with != 1 results"); } - if (opInst->getResult(0)->getType() != state->superVectorType) { + if (inst->getResult(0)->getType() != state->superVectorType) { return inst->emitError("Op does not return a supervector."); } auto *clone = - instantiate(&b, opInst, state->hwVectorType, state->substitutionsMap); + instantiate(&b, inst, state->hwVectorType, state->substitutionsMap); if (!clone) { return true; } state->substitutionsMap->insert( - std::make_pair(opInst->getResult(0), clone->getResult(0))); + std::make_pair(inst->getResult(0), clone->getResult(0))); return false; } @@ -645,7 +644,7 @@ static bool emitSlice(MaterializationState *state, } LLVM_DEBUG(dbgs() << "\nMLFunction is now\n"); - LLVM_DEBUG(cast<OperationInst>((*slice)[0])->getFunction()->print(dbgs())); + LLVM_DEBUG((*slice)[0]->getFunction()->print(dbgs())); // slice are topologically sorted, we can just erase them in reverse // order. Reverse iterator does not just work simply with an operator* @@ -677,7 +676,7 @@ static bool emitSlice(MaterializationState *state, /// scope. /// TODO(ntv): please document return value. static bool materialize(Function *f, - const SetVector<OperationInst *> &terminators, + const SetVector<Instruction *> &terminators, MaterializationState *state) { DenseSet<Instruction *> seen; DominanceInfo domInfo(f); @@ -757,18 +756,17 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) { // Capture terminators; i.e. vector_transfer_write ops involving a strict // super-vector of subVectorType. auto filter = [subVectorType](const Instruction &inst) { - const auto &opInst = cast<OperationInst>(inst); - if (!opInst.isa<VectorTransferWriteOp>()) { + if (!inst.isa<VectorTransferWriteOp>()) { return false; } - return matcher::operatesOnSuperVectors(opInst, subVectorType); + return matcher::operatesOnSuperVectors(inst, subVectorType); }; auto pat = Op(filter); SmallVector<NestedMatch, 8> matches; pat.match(f, &matches); - SetVector<OperationInst *> terminators; + SetVector<Instruction *> terminators; for (auto m : matches) { - terminators.insert(cast<OperationInst>(m.getMatchedInstruction())); + terminators.insert(m.getMatchedInstruction()); } auto fail = materialize(f, terminators, &state); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index b9386c384dd..b2b69dc7b6d 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -75,12 +75,12 @@ struct MemRefDataFlowOpt : public FunctionPass, InstWalker<MemRefDataFlowOpt> { PassResult runOnFunction(Function *f) override; - void visitInstruction(OperationInst *opInst); + void visitInstruction(Instruction *opInst); // A list of memref's that are potentially dead / could be eliminated. SmallPtrSet<Value *, 4> memrefsToErase; // Load op's whose results were replaced by those forwarded from stores. - std::vector<OperationInst *> loadOpsToErase; + std::vector<Instruction *> loadOpsToErase; DominanceInfo *domInfo = nullptr; PostDominanceInfo *postDomInfo = nullptr; @@ -100,22 +100,22 @@ FunctionPass *mlir::createMemRefDataFlowOptPass() { // This is a straightforward implementation not optimized for speed. Optimize // this in the future if needed. -void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) { - OperationInst *lastWriteStoreOp = nullptr; +void MemRefDataFlowOpt::visitInstruction(Instruction *opInst) { + Instruction *lastWriteStoreOp = nullptr; auto loadOp = opInst->dyn_cast<LoadOp>(); if (!loadOp) return; - OperationInst *loadOpInst = opInst; + Instruction *loadOpInst = opInst; // First pass over the use list to get minimum number of surrounding // loops common between the load op and the store op, with min taken across // all store ops. - SmallVector<OperationInst *, 8> storeOps; + SmallVector<Instruction *, 8> storeOps; unsigned minSurroundingLoops = getNestingDepth(*loadOpInst); for (InstOperand &use : loadOp->getMemRef()->getUses()) { - auto storeOp = cast<OperationInst>(use.getOwner())->dyn_cast<StoreOp>(); + auto storeOp = use.getOwner()->dyn_cast<StoreOp>(); if (!storeOp) continue; auto *storeOpInst = storeOp->getInstruction(); @@ -131,11 +131,11 @@ void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) { // and loadOp. // The list of store op candidates for forwarding - need to satisfy the // conditions listed at the top. - SmallVector<OperationInst *, 8> fwdingCandidates; + SmallVector<Instruction *, 8> fwdingCandidates; // Store ops that have a dependence into the load (even if they aren't // forwarding candidates). Each forwarding candidate will be checked for a // post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores. - SmallVector<OperationInst *, 8> depSrcStores; + SmallVector<Instruction *, 8> depSrcStores; for (auto *storeOpInst : storeOps) { MemRefAccess srcAccess(storeOpInst); MemRefAccess destAccess(loadOpInst); @@ -197,7 +197,7 @@ void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) { // that postdominates all 'depSrcStores' (if such a store exists) is the // unique store providing the value to the load, i.e., provably the last // writer to that memref loc. - if (llvm::all_of(depSrcStores, [&](OperationInst *depStore) { + if (llvm::all_of(depSrcStores, [&](Instruction *depStore) { return postDomInfo->postDominates(storeOpInst, depStore); })) { lastWriteStoreOp = storeOpInst; @@ -246,24 +246,22 @@ PassResult MemRefDataFlowOpt::runOnFunction(Function *f) { // to do this as well, but we'll do it here since we collected these anyway. for (auto *memref : memrefsToErase) { // If the memref hasn't been alloc'ed in this function, skip. - OperationInst *defInst = memref->getDefiningInst(); + Instruction *defInst = memref->getDefiningInst(); if (!defInst || !defInst->isa<AllocOp>()) // TODO(mlir-team): if the memref was returned by a 'call' instruction, we // could still erase it if the call had no side-effects. continue; if (std::any_of(memref->use_begin(), memref->use_end(), [&](InstOperand &use) { - auto *ownerInst = cast<OperationInst>(use.getOwner()); + auto *ownerInst = use.getOwner(); return (!ownerInst->isa<StoreOp>() && !ownerInst->isa<DeallocOp>()); })) continue; // Erase all stores, the dealloc, and the alloc on the memref. - for (auto it = memref->use_begin(), e = memref->use_end(); it != e;) { - auto &use = *(it++); - cast<OperationInst>(use.getOwner())->erase(); - } + for (auto &use : llvm::make_early_inc_range(memref->getUses())) + use.getOwner()->erase(); defInst->erase(); } diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 8d13800160d..ba3be5e95f4 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -61,7 +61,7 @@ FunctionPass *mlir::createPipelineDataTransferPass() { // Returns the position of the tag memref operand given a DMA instruction. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -static unsigned getTagMemRefPos(const OperationInst &dmaInst) { +static unsigned getTagMemRefPos(const Instruction &dmaInst) { assert(dmaInst.isa<DmaStartOp>() || dmaInst.isa<DmaWaitOp>()); if (dmaInst.isa<DmaStartOp>()) { // Second to last operand. @@ -142,7 +142,7 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) { // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). forOps.clear(); - f->walkPostOrder([&](OperationInst *opInst) { + f->walkPostOrder([&](Instruction *opInst) { if (auto forOp = opInst->dyn_cast<AffineForOp>()) forOps.push_back(forOp); }); @@ -180,33 +180,26 @@ static bool checkTagMatch(OpPointer<DmaStartOp> startOp, // Identify matching DMA start/finish instructions to overlap computation with. static void findMatchingStartFinishInsts( OpPointer<AffineForOp> forOp, - SmallVectorImpl<std::pair<OperationInst *, OperationInst *>> - &startWaitPairs) { + SmallVectorImpl<std::pair<Instruction *, Instruction *>> &startWaitPairs) { // Collect outgoing DMA instructions - needed to check for dependences below. SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps; for (auto &inst : *forOp->getBody()) { - auto *opInst = dyn_cast<OperationInst>(&inst); - if (!opInst) - continue; OpPointer<DmaStartOp> dmaStartOp; - if ((dmaStartOp = opInst->dyn_cast<DmaStartOp>()) && + if ((dmaStartOp = inst.dyn_cast<DmaStartOp>()) && dmaStartOp->isSrcMemorySpaceFaster()) outgoingDmaOps.push_back(dmaStartOp); } - SmallVector<OperationInst *, 4> dmaStartInsts, dmaFinishInsts; + SmallVector<Instruction *, 4> dmaStartInsts, dmaFinishInsts; for (auto &inst : *forOp->getBody()) { - auto *opInst = dyn_cast<OperationInst>(&inst); - if (!opInst) - continue; // Collect DMA finish instructions. - if (opInst->isa<DmaWaitOp>()) { - dmaFinishInsts.push_back(opInst); + if (inst.isa<DmaWaitOp>()) { + dmaFinishInsts.push_back(&inst); continue; } OpPointer<DmaStartOp> dmaStartOp; - if (!(dmaStartOp = opInst->dyn_cast<DmaStartOp>())) + if (!(dmaStartOp = inst.dyn_cast<DmaStartOp>())) continue; // Only DMAs incoming into higher memory spaces are pipelined for now. // TODO(bondhugula): handle outgoing DMA pipelining. @@ -236,7 +229,7 @@ static void findMatchingStartFinishInsts( } } if (!escapingUses) - dmaStartInsts.push_back(opInst); + dmaStartInsts.push_back(&inst); } // For each start instruction, we look for a matching finish instruction. @@ -262,7 +255,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) { return success(); } - SmallVector<std::pair<OperationInst *, OperationInst *>, 4> startWaitPairs; + SmallVector<std::pair<Instruction *, Instruction *>, 4> startWaitPairs; findMatchingStartFinishInsts(forOp, startWaitPairs); if (startWaitPairs.empty()) { @@ -335,7 +328,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) { } else { // If a slice wasn't created, the reachable affine_apply op's from its // operands are the ones that go with it. - SmallVector<OperationInst *, 4> affineApplyInsts; + SmallVector<Instruction *, 4> affineApplyInsts; SmallVector<Value *, 4> operands(dmaStartInst->getOperands()); getReachableAffineApplyOps(operands, affineApplyInsts); for (const auto *inst : affineApplyInsts) { @@ -356,13 +349,13 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) { for (auto &inst : *forOp->getBody()) { assert(instShiftMap.find(&inst) != instShiftMap.end()); shifts[s++] = instShiftMap[&inst]; - LLVM_DEBUG( - // Tagging instructions with shifts for debugging purposes. - if (auto *opInst = dyn_cast<OperationInst>(&inst)) { - FuncBuilder b(opInst); - opInst->setAttr(b.getIdentifier("shift"), - b.getI64IntegerAttr(shifts[s - 1])); - }); + + // Tagging instructions with shifts for debugging purposes. + LLVM_DEBUG({ + FuncBuilder b(&inst); + inst.setAttr(b.getIdentifier("shift"), + b.getI64IntegerAttr(shifts[s - 1])); + }); } if (!isInstwiseShiftValid(forOp, shifts)) { diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index a9fcfc5bd11..29509911e31 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -64,7 +64,7 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { } PassResult SimplifyAffineStructures::runOnFunction(Function *f) { - f->walk([&](OperationInst *opInst) { + f->walk([&](Instruction *opInst) { for (auto attr : opInst->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) { MutableAffineMap mMap(mapAttr.getValue()); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 790f971bb58..45c57e2f307 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -38,13 +38,13 @@ public: worklist.reserve(64); // Add all operations to the worklist. - fn->walk([&](OperationInst *inst) { addToWorklist(inst); }); + fn->walk([&](Instruction *inst) { addToWorklist(inst); }); } /// Perform the rewrites. void simplifyFunction(); - void addToWorklist(OperationInst *op) { + void addToWorklist(Instruction *op) { // Check to see if the worklist already contains this op. if (worklistMap.count(op)) return; @@ -53,7 +53,7 @@ public: worklist.push_back(op); } - OperationInst *popFromWorklist() { + Instruction *popFromWorklist() { auto *op = worklist.back(); worklist.pop_back(); @@ -65,7 +65,7 @@ public: /// If the specified operation is in the worklist, remove it. If not, this is /// a no-op. - void removeFromWorklist(OperationInst *op) { + void removeFromWorklist(Instruction *op) { auto it = worklistMap.find(op); if (it != worklistMap.end()) { assert(worklist[it->second] == op && "malformed worklist data structure"); @@ -77,7 +77,7 @@ public: protected: // Implement the hook for creating operations, and make sure that newly // created ops are added to the worklist for processing. - OperationInst *createOperation(const OperationState &state) override { + Instruction *createOperation(const OperationState &state) override { auto *result = builder.createOperation(state); addToWorklist(result); return result; @@ -85,20 +85,18 @@ protected: // If an operation is about to be removed, make sure it is not in our // worklist anymore because we'd get dangling references to it. - void notifyOperationRemoved(OperationInst *op) override { + void notifyOperationRemoved(Instruction *op) override { removeFromWorklist(op); } // When the root of a pattern is about to be replaced, it can trigger // simplifications to its users - make sure to add them to the worklist // before the root is changed. - void notifyRootReplaced(OperationInst *op) override { + void notifyRootReplaced(Instruction *op) override { for (auto *result : op->getResults()) // TODO: Add a result->getUsers() iterator. - for (auto &user : result->getUses()) { - if (auto *op = dyn_cast<OperationInst>(user.getOwner())) - addToWorklist(op); - } + for (auto &user : result->getUses()) + addToWorklist(user.getOwner()); // TODO: Walk the operand list dropping them as we go. If any of them // drop to zero uses, then add them to the worklist to allow them to be @@ -116,13 +114,13 @@ private: /// need to be revisited, plus their index in the worklist. This allows us to /// efficiently remove operations from the worklist when they are erased from /// the function, even if they aren't the root of a pattern. - std::vector<OperationInst *> worklist; - DenseMap<OperationInst *, unsigned> worklistMap; + std::vector<Instruction *> worklist; + DenseMap<Instruction *, unsigned> worklistMap; /// As part of canonicalization, we move constants to the top of the entry /// block of the current function and de-duplicate them. This keeps track of /// constants we have done this for. - DenseMap<std::pair<Attribute, Type>, OperationInst *> uniquedConstants; + DenseMap<std::pair<Attribute, Type>, Instruction *> uniquedConstants; }; }; // end anonymous namespace @@ -229,10 +227,8 @@ void GreedyPatternRewriteDriver::simplifyFunction() { // revisit them. // // TODO: Add a result->getUsers() iterator. - for (auto &operand : op->getResult(i)->getUses()) { - if (auto *op = dyn_cast<OperationInst>(operand.getOwner())) - addToWorklist(op); - } + for (auto &operand : op->getResult(i)->getUses()) + addToWorklist(operand.getOwner()); res->replaceAllUsesWith(cstValue); } @@ -267,10 +263,8 @@ void GreedyPatternRewriteDriver::simplifyFunction() { if (res->use_empty()) // ignore dead uses. continue; - for (auto &operand : op->getResult(i)->getUses()) { - if (auto *op = dyn_cast<OperationInst>(operand.getOwner())) - addToWorklist(op); - } + for (auto &operand : op->getResult(i)->getUses()) + addToWorklist(operand.getOwner()); res->replaceAllUsesWith(resultValues[i]); } } diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 153557de04a..5bf17989bef 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -101,7 +101,7 @@ bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) { // Replaces all IV uses to its single iteration value. auto *iv = forOp->getInductionVar(); - OperationInst *forInst = forOp->getInstruction(); + Instruction *forInst = forOp->getInstruction(); if (!iv->use_empty()) { if (forOp->hasConstantLowerBound()) { auto *mlFunc = forInst->getFunction(); @@ -135,7 +135,7 @@ bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) { /// their body into the containing Block. void mlir::promoteSingleIterationLoops(Function *f) { // Gathers all innermost loops through a post order pruned walk. - f->walkPostOrder([](OperationInst *inst) { + f->walkPostOrder([](Instruction *inst) { if (auto forOp = inst->dyn_cast<AffineForOp>()) promoteIfSingleIteration(forOp); }); @@ -394,11 +394,10 @@ bool mlir::loopUnrollByFactor(OpPointer<AffineForOp> forOp, return false; // Generate the cleanup loop if trip count isn't a multiple of unrollFactor. - OperationInst *forInst = forOp->getInstruction(); + Instruction *forInst = forOp->getInstruction(); if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) { FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst)); - auto cleanupForInst = - cast<OperationInst>(builder.clone(*forInst))->cast<AffineForOp>(); + auto cleanupForInst = builder.clone(*forInst)->cast<AffineForOp>(); auto clLbMap = getCleanupLoopLowerBound(forOp, unrollFactor, &builder); assert(clLbMap && "cleanup loop lower bound map for single result bound maps can " diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 879a4f4b585..524e8d542f5 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -37,7 +37,7 @@ using namespace mlir; /// Return true if this operation dereferences one or more memref's. // Temporary utility: will be replaced when this is modeled through // side-effects/op traits. TODO(b/117228571) -static bool isMemRefDereferencingOp(const OperationInst &op) { +static bool isMemRefDereferencingOp(const Instruction &op) { if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() || op.isa<DmaWaitOp>()) return true; @@ -76,12 +76,11 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, std::make_unique<PostDominanceInfo>(postDomInstFilter->getFunction()); // The ops where memref replacement succeeds are replaced with new ones. - SmallVector<OperationInst *, 8> opsToErase; + SmallVector<Instruction *, 8> opsToErase; // Walk all uses of old memref. Operation using the memref gets replaced. - for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) { - InstOperand &use = *(it++); - auto *opInst = cast<OperationInst>(use.getOwner()); + for (auto &use : llvm::make_early_inc_range(oldMemRef->getUses())) { + auto *opInst = use.getOwner(); // Skip this use if it's not dominated by domInstFilter. if (domInstFilter && !domInfo->dominates(domInstFilter, opInst)) @@ -217,8 +216,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, /// uses besides this opInst; otherwise returns the list of affine_apply /// operations created in output argument `sliceOps`. void mlir::createAffineComputationSlice( - OperationInst *opInst, - SmallVectorImpl<OpPointer<AffineApplyOp>> *sliceOps) { + Instruction *opInst, SmallVectorImpl<OpPointer<AffineApplyOp>> *sliceOps) { // Collect all operands that are results of affine apply ops. SmallVector<Value *, 4> subOperands; subOperands.reserve(opInst->getNumOperands()); @@ -230,7 +228,7 @@ void mlir::createAffineComputationSlice( } // Gather sequence of AffineApplyOps reachable from 'subOperands'. - SmallVector<OperationInst *, 4> affineApplyOps; + SmallVector<Instruction *, 4> affineApplyOps; getReachableAffineApplyOps(subOperands, affineApplyOps); // Skip transforming if there are no affine maps to compose. if (affineApplyOps.empty()) @@ -341,8 +339,7 @@ bool mlir::constantFoldBounds(OpPointer<AffineForOp> forInst) { } void mlir::remapFunctionAttrs( - OperationInst &op, - const DenseMap<Attribute, FunctionAttr> &remappingTable) { + Instruction &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) { for (auto attr : op.getAttrs()) { // Do the remapping, if we got the same thing back, then it must contain // functions that aren't getting remapped. diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index a9b9752ef51..7d51637a6e1 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -110,17 +110,13 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) { // Only filter instructions that operate on a strict super-vector and have one // return. This makes testing easier. auto filter = [subVectorType](const Instruction &inst) { - auto *opInst = dyn_cast<OperationInst>(&inst); - if (!opInst) { - return false; - } assert(subVectorType.getElementType() == Type::getF32(subVectorType.getContext()) && "Only f32 supported for now"); - if (!matcher::operatesOnSuperVectors(*opInst, subVectorType)) { + if (!matcher::operatesOnSuperVectors(inst, subVectorType)) { return false; } - if (opInst->getNumResults() != 1) { + if (inst.getNumResults() != 1) { return false; } return true; @@ -129,7 +125,7 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) { SmallVector<NestedMatch, 8> matches; pat.match(f, &matches); for (auto m : matches) { - auto *opInst = cast<OperationInst>(m.getMatchedInstruction()); + auto *opInst = m.getMatchedInstruction(); // This is a unit test that only checks and prints shape ratio. // As a consequence we write only Ops with a single return type for the // purpose of this test. If we need to test more intricate behavior in the @@ -159,8 +155,7 @@ static NestedPattern patternTestSlicingOps() { using matcher::Op; // Match all OpInstructions with the kTestSlicingOpName name. auto filter = [](const Instruction &inst) { - const auto &opInst = cast<OperationInst>(inst); - return opInst.getName().getStringRef() == kTestSlicingOpName; + return inst.getName().getStringRef() == kTestSlicingOpName; }; return Op(filter); } @@ -209,8 +204,7 @@ void VectorizerTestPass::testSlicing(Function *f) { } static bool customOpWithAffineMapAttribute(const Instruction &inst) { - const auto &opInst = cast<OperationInst>(inst); - return opInst.getName().getStringRef() == + return inst.getName().getStringRef() == VectorizerTestPass::kTestAffineMapOpName; } @@ -222,7 +216,7 @@ void VectorizerTestPass::testComposeMaps(Function *f) { SmallVector<AffineMap, 4> maps; maps.reserve(matches.size()); for (auto m : llvm::reverse(matches)) { - auto *opInst = cast<OperationInst>(m.getMatchedInstruction()); + auto *opInst = m.getMatchedInstruction(); auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName) .cast<AffineMapAttr>() .getValue(); @@ -236,13 +230,11 @@ void VectorizerTestPass::testComposeMaps(Function *f) { } static bool affineApplyOp(const Instruction &inst) { - const auto &opInst = cast<OperationInst>(inst); - return opInst.isa<AffineApplyOp>(); + return inst.isa<AffineApplyOp>(); } static bool singleResultAffineApplyOpWithoutUses(const Instruction &inst) { - const auto &opInst = cast<OperationInst>(inst); - auto app = opInst.dyn_cast<AffineApplyOp>(); + auto app = inst.dyn_cast<AffineApplyOp>(); return app && app->use_empty(); } @@ -259,8 +251,7 @@ void VectorizerTestPass::testNormalizeMaps(Function *f) { SmallVector<NestedMatch, 8> matches; pattern.match(f, &matches); for (auto m : matches) { - auto app = - cast<OperationInst>(m.getMatchedInstruction())->cast<AffineApplyOp>(); + auto app = m.getMatchedInstruction()->cast<AffineApplyOp>(); FuncBuilder b(m.getMatchedInstruction()); SmallVector<Value *, 8> operands(app->getOperands()); makeComposedAffineApply(&b, app->getLoc(), app->getAffineMap(), operands); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 661861dcfd4..5a8d5d24661 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -723,22 +723,22 @@ namespace { struct VectorizationState { /// Adds an entry of pre/post vectorization instructions in the state. - void registerReplacement(OperationInst *key, OperationInst *value); + void registerReplacement(Instruction *key, Instruction *value); /// When the current vectorization pattern is successful, this erases the /// instructions that were marked for erasure in the proper order and resets /// the internal state for the next pattern. void finishVectorizationPattern(); - // In-order tracking of original OperationInst that have been vectorized. + // In-order tracking of original Instruction that have been vectorized. // Erase in reverse order. - SmallVector<OperationInst *, 16> toErase; - // Set of OperationInst that have been vectorized (the values in the + SmallVector<Instruction *, 16> toErase; + // Set of Instruction that have been vectorized (the values in the // vectorizationMap for hashed access). The vectorizedSet is used in // particular to filter the instructions that have already been vectorized by // this pattern, when iterating over nested loops in this pattern. - DenseSet<OperationInst *> vectorizedSet; - // Map of old scalar OperationInst to new vectorized OperationInst. - DenseMap<OperationInst *, OperationInst *> vectorizationMap; + DenseSet<Instruction *> vectorizedSet; + // Map of old scalar Instruction to new vectorized Instruction. + DenseMap<Instruction *, Instruction *> vectorizationMap; // Map of old scalar Value to new vectorized Value. DenseMap<const Value *, Value *> replacementMap; // The strategy drives which loop to vectorize by which amount. @@ -747,17 +747,17 @@ struct VectorizationState { // vectorizeOperations function. They consist of the subset of load operations // that have been vectorized. They can be retrieved from `vectorizationMap` // but it is convenient to keep track of them in a separate data structure. - DenseSet<OperationInst *> roots; + DenseSet<Instruction *> roots; // Terminator instructions for the worklist in the vectorizeOperations // function. They consist of the subset of store operations that have been // vectorized. They can be retrieved from `vectorizationMap` but it is // convenient to keep track of them in a separate data structure. Since they // do not necessarily belong to use-def chains starting from loads (e.g // storing a constant), we need to handle them in a post-pass. - DenseSet<OperationInst *> terminators; + DenseSet<Instruction *> terminators; // Checks that the type of `inst` is StoreOp and adds it to the terminators // set. - void registerTerminator(OperationInst *inst); + void registerTerminator(Instruction *inst); private: void registerReplacement(const Value *key, Value *value); @@ -765,8 +765,8 @@ private: } // end namespace -void VectorizationState::registerReplacement(OperationInst *key, - OperationInst *value) { +void VectorizationState::registerReplacement(Instruction *key, + Instruction *value) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op: "); LLVM_DEBUG(key->print(dbgs())); LLVM_DEBUG(dbgs() << " into "); @@ -785,7 +785,7 @@ void VectorizationState::registerReplacement(OperationInst *key, } } -void VectorizationState::registerTerminator(OperationInst *inst) { +void VectorizationState::registerTerminator(Instruction *inst) { assert(inst->isa<StoreOp>() && "terminator must be a StoreOp"); assert(terminators.count(inst) == 0 && "terminator was already inserted previously"); @@ -867,17 +867,16 @@ static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step, if (!matcher::isLoadOrStore(inst)) { return false; } - auto *opInst = cast<OperationInst>(&inst); - return state->vectorizationMap.count(opInst) == 0 && - state->vectorizedSet.count(opInst) == 0 && - state->roots.count(opInst) == 0 && - state->terminators.count(opInst) == 0; + return state->vectorizationMap.count(&inst) == 0 && + state->vectorizedSet.count(&inst) == 0 && + state->roots.count(&inst) == 0 && + state->terminators.count(&inst) == 0; }; auto loadAndStores = matcher::Op(notVectorizedThisPattern); SmallVector<NestedMatch, 8> loadAndStoresMatches; loadAndStores.match(loop->getInstruction(), &loadAndStoresMatches); for (auto ls : loadAndStoresMatches) { - auto *opInst = cast<OperationInst>(ls.getMatchedInstruction()); + auto *opInst = ls.getMatchedInstruction(); auto load = opInst->dyn_cast<LoadOp>(); auto store = opInst->dyn_cast<StoreOp>(); LLVM_DEBUG(opInst->print(dbgs())); @@ -900,7 +899,7 @@ static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step, static FilterFunctionType isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) { return [fastestVaryingMemRefDimension](const Instruction &forInst) { - auto loop = cast<OperationInst>(forInst).cast<AffineForOp>(); + auto loop = forInst.cast<AffineForOp>(); return isVectorizableLoopAlongFastestVaryingMemRefDim( loop, fastestVaryingMemRefDimension); }; @@ -915,7 +914,7 @@ static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches, /// recursively in DFS post-order. static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) { auto *loopInst = oneMatch.getMatchedInstruction(); - auto loop = cast<OperationInst>(loopInst)->cast<AffineForOp>(); + auto loop = loopInst->cast<AffineForOp>(); auto childrenMatches = oneMatch.getMatchedChildren(); // 1. DFS postorder recursion, if any of my children fails, I fail too. @@ -977,15 +976,14 @@ static Value *vectorizeConstant(Instruction *inst, const ConstantOp &constant, Location loc = inst->getLoc(); auto vectorType = type.cast<VectorType>(); auto attr = SplatElementsAttr::get(vectorType, constant.getValue()); - auto *constantOpInst = cast<OperationInst>(constant.getInstruction()); + auto *constantOpInst = constant.getInstruction(); OperationState state( b.getContext(), loc, constantOpInst->getName().getStringRef(), {}, {vectorType}, {make_pair(Identifier::get("value", b.getContext()), attr)}); - auto *splat = cast<OperationInst>(b.createOperation(state)); - return splat->getResult(0); + return b.createOperation(state)->getResult(0); } /// Returns a uniqu'ed VectorType. @@ -997,8 +995,7 @@ static Type getVectorType(Value *v, const VectorizationState &state) { if (!VectorType::isValidElementType(v->getType())) { return Type(); } - auto *definingOpInst = cast<OperationInst>(v->getDefiningInst()); - if (state.vectorizedSet.count(definingOpInst) > 0) { + if (state.vectorizedSet.count(v->getDefiningInst()) > 0) { return v->getType().cast<VectorType>(); } return VectorType::get(state.strategy->vectorSizes, v->getType()); @@ -1029,9 +1026,8 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst, VectorizationState *state) { LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: "); LLVM_DEBUG(operand->print(dbgs())); - auto *definingInstruction = cast<OperationInst>(operand->getDefiningInst()); // 1. If this value has already been vectorized this round, we are done. - if (state->vectorizedSet.count(definingInstruction) > 0) { + if (state->vectorizedSet.count(operand->getDefiningInst()) > 0) { LLVM_DEBUG(dbgs() << " -> already vector operand"); return operand; } @@ -1062,7 +1058,7 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst, return nullptr; }; -/// Encodes OperationInst-specific behavior for vectorization. In general we +/// Encodes Instruction-specific behavior for vectorization. In general we /// assume that all operands of an op must be vectorized but this is not always /// true. In the future, it would be nice to have a trait that describes how a /// particular operation vectorizes. For now we implement the case distinction @@ -1071,9 +1067,8 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst, /// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized. /// Maybe some Ops are not vectorizable or require some tricky logic, we cannot /// do one-off logic here; ideally it would be TableGen'd. -static OperationInst *vectorizeOneOperationInst(FuncBuilder *b, - OperationInst *opInst, - VectorizationState *state) { +static Instruction *vectorizeOneInstruction(FuncBuilder *b, Instruction *opInst, + VectorizationState *state) { // Sanity checks. assert(!opInst->isa<LoadOp>() && "all loads must have already been fully vectorized independently"); @@ -1094,7 +1089,7 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b, LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = b.create<VectorTransferWriteOp>( opInst->getLoc(), vectorValue, memRef, indices, permutationMap); - auto *res = cast<OperationInst>(transfer->getInstruction()); + auto *res = transfer->getInstruction(); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res); // "Terminators" (i.e. StoreOps) are erased on the spot. opInst->erase(); @@ -1119,8 +1114,8 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b, // Create a clone of the op with the proper operands and return types. // TODO(ntv): The following assumes there is always an op with a fixed // name that works both in scalar mode and vector mode. - // TODO(ntv): Is it worth considering an OperationInst.clone operation - // which changes the type so we can promote an OperationInst with less + // TODO(ntv): Is it worth considering an Instruction.clone operation + // which changes the type so we can promote an Instruction with less // boilerplate? OperationState newOp(b->getContext(), opInst->getLoc(), opInst->getName().getStringRef(), operands, types, @@ -1129,22 +1124,22 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b, return b->createOperation(newOp); } -/// Iterates over the OperationInst in the loop and rewrites them using their +/// Iterates over the Instruction in the loop and rewrites them using their /// vectorized counterpart by: -/// 1. iteratively building a worklist of uses of the OperationInst vectorized +/// 1. iteratively building a worklist of uses of the Instruction vectorized /// so far by this pattern; -/// 2. for each OperationInst in the worklist, create the vector form of this +/// 2. for each Instruction in the worklist, create the vector form of this /// operation and replace all its uses by the vectorized form. For this step, /// the worklist must be traversed in order; /// 3. verify that all operands of the newly vectorized operation have been /// vectorized by this pattern. static bool vectorizeOperations(VectorizationState *state) { // 1. create initial worklist with the uses of the roots. - SetVector<OperationInst *> worklist; - auto insertUsesOf = [&worklist, state](OperationInst *vectorized) { + SetVector<Instruction *> worklist; + auto insertUsesOf = [&worklist, state](Instruction *vectorized) { for (auto *r : vectorized->getResults()) for (auto &u : r->getUses()) { - auto *inst = cast<OperationInst>(u.getOwner()); + auto *inst = u.getOwner(); // Don't propagate to terminals, a separate pass is needed for those. // TODO(ntv)[b/119759136]: use isa<> once Op is implemented. if (state->terminators.count(inst) > 0) { @@ -1166,7 +1161,7 @@ static bool vectorizeOperations(VectorizationState *state) { // 2. Create vectorized form of the instruction. // Insert it just before inst, on success register inst as replaced. FuncBuilder b(inst); - auto *vectorizedInst = vectorizeOneOperationInst(&b, inst, state); + auto *vectorizedInst = vectorizeOneInstruction(&b, inst, state); if (!vectorizedInst) { return true; } @@ -1179,7 +1174,7 @@ static bool vectorizeOperations(VectorizationState *state) { // 4. Augment the worklist with uses of the instruction we just vectorized. // This preserves the proper order in the worklist. - apply(insertUsesOf, ArrayRef<OperationInst *>{inst}); + apply(insertUsesOf, ArrayRef<Instruction *>{inst}); } return false; } @@ -1189,8 +1184,7 @@ static bool vectorizeOperations(VectorizationState *state) { /// Each root may succeed independently but will otherwise clean after itself if /// anything below it fails. static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { - auto loop = - cast<OperationInst>(m.getMatchedInstruction())->cast<AffineForOp>(); + auto loop = m.getMatchedInstruction()->cast<AffineForOp>(); VectorizationState state; state.strategy = strategy; @@ -1207,8 +1201,7 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { } auto *loopInst = loop->getInstruction(); FuncBuilder builder(loopInst); - auto clonedLoop = - cast<OperationInst>(builder.clone(*loopInst))->cast<AffineForOp>(); + auto clonedLoop = builder.clone(*loopInst)->cast<AffineForOp>(); auto fail = doVectorize(m, &state); /// Sets up error handling for this root loop. This is how the root match @@ -1248,12 +1241,12 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { } // Finally, vectorize the terminators. If anything fails to vectorize, skip. - auto vectorizeOrFail = [&fail, &state](OperationInst *inst) { + auto vectorizeOrFail = [&fail, &state](Instruction *inst) { if (fail) { return; } FuncBuilder b(inst); - auto *res = vectorizeOneOperationInst(&b, inst, &state); + auto *res = vectorizeOneInstruction(&b, inst, &state); if (res == nullptr) { fail = true; } |

