summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-02-04 10:38:47 -0800
committerjpienaar <jpienaar@google.com>2019-03-29 16:10:53 -0700
commitb499277fb648c44907443ce44ec6bcc6b7596039 (patch)
tree7d61826527e189c2952cbe5b2498b2c43ab6f839 /mlir
parent44e040dd635a8ce4b362cc81213f5e791b20830e (diff)
downloadbcm5719-llvm-b499277fb648c44907443ce44ec6bcc6b7596039.tar.gz
bcm5719-llvm-b499277fb648c44907443ce44ec6bcc6b7596039.zip
Remove remaining usages of OperationInst in lib/Transforms.
PiperOrigin-RevId: 232323671
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Transforms/CSE.cpp43
-rw-r--r--mlir/lib/Transforms/ComposeAffineMaps.cpp10
-rw-r--r--mlir/lib/Transforms/ConstantFold.cpp8
-rw-r--r--mlir/lib/Transforms/DialectConversion.cpp31
-rw-r--r--mlir/lib/Transforms/DmaGeneration.cpp5
-rw-r--r--mlir/lib/Transforms/LoopFusion.cpp134
-rw-r--r--mlir/lib/Transforms/LoopTiling.cpp5
-rw-r--r--mlir/lib/Transforms/LoopUnroll.cpp4
-rw-r--r--mlir/lib/Transforms/LoopUnrollAndJam.cpp10
-rw-r--r--mlir/lib/Transforms/LowerAffine.cpp12
-rw-r--r--mlir/lib/Transforms/LowerVectorTransfers.cpp5
-rw-r--r--mlir/lib/Transforms/MaterializeVectors.cpp44
-rw-r--r--mlir/lib/Transforms/MemRefDataFlowOpt.cpp30
-rw-r--r--mlir/lib/Transforms/PipelineDataTransfer.cpp43
-rw-r--r--mlir/lib/Transforms/SimplifyAffineStructures.cpp2
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp38
-rw-r--r--mlir/lib/Transforms/Utils/LoopUtils.cpp9
-rw-r--r--mlir/lib/Transforms/Utils/Utils.cpp17
-rw-r--r--mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp27
-rw-r--r--mlir/lib/Transforms/Vectorize.cpp91
20 files changed, 251 insertions, 317 deletions
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index e471b6792c5..63a676d7b52 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -39,10 +39,10 @@ using namespace mlir;
namespace {
// TODO(riverriddle) Handle commutative operations.
-struct SimpleOperationInfo : public llvm::DenseMapInfo<OperationInst *> {
- static unsigned getHashValue(const OperationInst *op) {
+struct SimpleOperationInfo : public llvm::DenseMapInfo<Instruction *> {
+ static unsigned getHashValue(const Instruction *op) {
// Hash the operations based upon their:
- // - OperationInst Name
+ // - Instruction Name
// - Attributes
// - Result Types
// - Operands
@@ -51,7 +51,7 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<OperationInst *> {
hash_combine_range(op->result_type_begin(), op->result_type_end()),
hash_combine_range(op->operand_begin(), op->operand_end()));
}
- static bool isEqual(const OperationInst *lhs, const OperationInst *rhs) {
+ static bool isEqual(const Instruction *lhs, const Instruction *rhs) {
if (lhs == rhs)
return true;
if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
@@ -89,8 +89,8 @@ struct CSE : public FunctionPass {
/// Shared implementation of operation elimination and scoped map definitions.
using AllocatorTy = llvm::RecyclingAllocator<
llvm::BumpPtrAllocator,
- llvm::ScopedHashTableVal<OperationInst *, OperationInst *>>;
- using ScopedMapTy = llvm::ScopedHashTable<OperationInst *, OperationInst *,
+ llvm::ScopedHashTableVal<Instruction *, Instruction *>>;
+ using ScopedMapTy = llvm::ScopedHashTable<Instruction *, Instruction *,
SimpleOperationInfo, AllocatorTy>;
/// Represents a single entry in the depth first traversal of a CFG.
@@ -111,7 +111,7 @@ struct CSE : public FunctionPass {
/// Attempt to eliminate a redundant operation. Returns true if the operation
/// was marked for removal, false otherwise.
- bool simplifyOperation(OperationInst *op);
+ bool simplifyOperation(Instruction *op);
void simplifyBlock(Block *bb);
@@ -122,14 +122,14 @@ private:
ScopedMapTy knownValues;
/// Operations marked as dead and to be erased.
- std::vector<OperationInst *> opsToErase;
+ std::vector<Instruction *> opsToErase;
};
} // end anonymous namespace
char CSE::passID = 0;
/// Attempt to eliminate a redundant operation.
-bool CSE::simplifyOperation(OperationInst *op) {
+bool CSE::simplifyOperation(Instruction *op) {
// TODO(riverriddle) We currently only eliminate non side-effecting
// operations.
if (!op->hasNoSideEffect())
@@ -166,23 +166,16 @@ bool CSE::simplifyOperation(OperationInst *op) {
void CSE::simplifyBlock(Block *bb) {
for (auto &i : *bb) {
- switch (i.getKind()) {
- case Instruction::Kind::OperationInst: {
- auto *opInst = cast<OperationInst>(&i);
-
- // If the operation is simplified, we don't process any held block lists.
- if (simplifyOperation(opInst))
- continue;
-
- // Simplify any held blocks.
- for (auto &blockList : opInst->getBlockLists()) {
- for (auto &b : blockList) {
- ScopedMapTy::ScopeTy scope(knownValues);
- simplifyBlock(&b);
- }
+ // If the operation is simplified, we don't process any held block lists.
+ if (simplifyOperation(&i))
+ continue;
+
+ // Simplify any held blocks.
+ for (auto &blockList : i.getBlockLists()) {
+ for (auto &b : blockList) {
+ ScopedMapTy::ScopeTy scope(knownValues);
+ simplifyBlock(&b);
}
- break;
- }
}
}
}
diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp
index 4f960ea73af..4a6430dc9be 100644
--- a/mlir/lib/Transforms/ComposeAffineMaps.cpp
+++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp
@@ -48,7 +48,7 @@ namespace {
struct ComposeAffineMaps : public FunctionPass, InstWalker<ComposeAffineMaps> {
explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {}
PassResult runOnFunction(Function *f) override;
- void visitInstruction(OperationInst *opInst);
+ void visitInstruction(Instruction *opInst);
SmallVector<OpPointer<AffineApplyOp>, 8> affineApplyOps;
@@ -64,14 +64,12 @@ FunctionPass *mlir::createComposeAffineMapsPass() {
}
static bool affineApplyOp(const Instruction &inst) {
- const auto &opInst = cast<OperationInst>(inst);
- return opInst.isa<AffineApplyOp>();
+ return inst.isa<AffineApplyOp>();
}
-void ComposeAffineMaps::visitInstruction(OperationInst *opInst) {
- if (auto afOp = opInst->dyn_cast<AffineApplyOp>()) {
+void ComposeAffineMaps::visitInstruction(Instruction *opInst) {
+ if (auto afOp = opInst->dyn_cast<AffineApplyOp>())
affineApplyOps.push_back(afOp);
- }
}
PassResult ComposeAffineMaps::runOnFunction(Function *f) {
diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp
index 859d0012fac..54486cdb293 100644
--- a/mlir/lib/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Transforms/ConstantFold.cpp
@@ -33,11 +33,11 @@ struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
// All constants in the function post folding.
SmallVector<Value *, 8> existingConstants;
// Operations that were folded and that need to be erased.
- std::vector<OperationInst *> opInstsToErase;
+ std::vector<Instruction *> opInstsToErase;
- bool foldOperation(OperationInst *op,
+ bool foldOperation(Instruction *op,
SmallVectorImpl<Value *> &existingConstants);
- void visitInstruction(OperationInst *op);
+ void visitInstruction(Instruction *op);
PassResult runOnFunction(Function *f) override;
static char passID;
@@ -49,7 +49,7 @@ char ConstantFold::passID = 0;
/// Attempt to fold the specified operation, updating the IR to match. If
/// constants are found, we keep track of them in the existingConstants list.
///
-void ConstantFold::visitInstruction(OperationInst *op) {
+void ConstantFold::visitInstruction(Instruction *op) {
// If this operation is an AffineForOp, then fold the bounds.
if (auto forOp = op->dyn_cast<AffineForOp>()) {
constantFoldBounds(forOp);
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 443e7750947..996416d9271 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -50,7 +50,7 @@ private:
// Utility that looks up a list of value in the value remapping table. Returns
// an empty vector if one of the values is not mapped yet.
SmallVector<Value *, 4>
- lookupValues(const llvm::iterator_range<OperationInst::const_operand_iterator>
+ lookupValues(const llvm::iterator_range<Instruction::const_operand_iterator>
&operands);
// Converts the given function to the dialect using hooks defined in
@@ -61,13 +61,13 @@ private:
// from `valueRemapping` and the converted blocks from `blockRemapping`, and
// passes them to `converter->rewriteTerminator` function defined in the
// pattern, together with `builder`.
- bool convertOpWithSuccessors(DialectOpConversion *converter,
- OperationInst *op, FuncBuilder &builder);
+ bool convertOpWithSuccessors(DialectOpConversion *converter, Instruction *op,
+ FuncBuilder &builder);
// Converts an operation without successors. Extracts the converted operands
// from `valueRemapping` and passes them to the `converter->rewrite` function
// defined in the pattern, together with `builder`.
- bool convertOp(DialectOpConversion *converter, OperationInst *op,
+ bool convertOp(DialectOpConversion *converter, Instruction *op,
FuncBuilder &builder);
// Converts a block by traversing its instructions sequentially, looking for
@@ -104,8 +104,7 @@ private:
} // end namespace mlir
SmallVector<Value *, 4> impl::FunctionConversion::lookupValues(
- const llvm::iterator_range<OperationInst::const_operand_iterator>
- &operands) {
+ const llvm::iterator_range<Instruction::const_operand_iterator> &operands) {
SmallVector<Value *, 4> remapped;
remapped.reserve(llvm::size(operands));
for (const Value *operand : operands) {
@@ -118,7 +117,7 @@ SmallVector<Value *, 4> impl::FunctionConversion::lookupValues(
}
bool impl::FunctionConversion::convertOpWithSuccessors(
- DialectOpConversion *converter, OperationInst *op, FuncBuilder &builder) {
+ DialectOpConversion *converter, Instruction *op, FuncBuilder &builder) {
SmallVector<Block *, 2> destinations;
destinations.reserve(op->getNumSuccessors());
SmallVector<Value *, 4> operands = lookupValues(op->getOperands());
@@ -149,7 +148,7 @@ bool impl::FunctionConversion::convertOpWithSuccessors(
}
bool impl::FunctionConversion::convertOp(DialectOpConversion *converter,
- OperationInst *op,
+ Instruction *op,
FuncBuilder &builder) {
auto operands = lookupValues(op->getOperands());
assert((!operands.empty() || op->getNumOperands() == 0) &&
@@ -174,24 +173,22 @@ bool impl::FunctionConversion::convertBlock(
// Iterate over ops and convert them.
for (Instruction &inst : *block) {
- auto op = dyn_cast<OperationInst>(&inst);
- if (!op) {
- inst.emitError("unsupported instruction (For/If)");
+ if (inst.getNumBlockLists() != 0) {
+ inst.emitError("unsupported region instruction");
return true;
}
// Find the first matching conversion and apply it.
bool converted = false;
for (auto *conversion : conversions) {
- if (!conversion->match(op))
+ if (!conversion->match(&inst))
continue;
- if (op->isTerminator() && op->getNumSuccessors() > 0) {
- if (convertOpWithSuccessors(conversion, op, builder))
- return true;
- } else {
- if (convertOp(conversion, op, builder))
+ if (inst.isTerminator() && inst.getNumSuccessors() > 0) {
+ if (convertOpWithSuccessors(conversion, &inst, builder))
return true;
+ } else if (convertOp(conversion, &inst, builder)) {
+ return true;
}
converted = true;
break;
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index 2bbb32036c2..92ae3767098 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -157,8 +157,7 @@ static void getMultiLevelStrides(const MemRefRegion &region,
/// dynamic shaped memref's for now. `numParamLoopIVs` is the number of
/// enclosing loop IVs of opInst (starting from the outermost) that the region
/// is parametric on.
-static bool getFullMemRefAsRegion(OperationInst *opInst,
- unsigned numParamLoopIVs,
+static bool getFullMemRefAsRegion(Instruction *opInst, unsigned numParamLoopIVs,
MemRefRegion *region) {
unsigned rank;
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
@@ -563,7 +562,7 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) {
fastBufferMap.clear();
// Walk this range of instructions to gather all memory regions.
- block->walk(begin, end, [&](OperationInst *opInst) {
+ block->walk(begin, end, [&](Instruction *opInst) {
// Gather regions to allocate to buffers in faster memory space.
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace)
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 304331320ac..d7d69e569e5 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -114,11 +114,11 @@ namespace {
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
SmallVector<OpPointer<AffineForOp>, 4> forOps;
- SmallVector<OperationInst *, 4> loadOpInsts;
- SmallVector<OperationInst *, 4> storeOpInsts;
+ SmallVector<Instruction *, 4> loadOpInsts;
+ SmallVector<Instruction *, 4> storeOpInsts;
bool hasNonForRegion = false;
- void visitInstruction(OperationInst *opInst) {
+ void visitInstruction(Instruction *opInst) {
if (opInst->isa<AffineForOp>())
forOps.push_back(opInst->cast<AffineForOp>());
else if (opInst->getNumBlockLists() != 0)
@@ -131,7 +131,7 @@ public:
};
// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
-static bool isMemRefDereferencingOp(const OperationInst &op) {
+static bool isMemRefDereferencingOp(const Instruction &op) {
if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
op.isa<DmaWaitOp>())
return true;
@@ -153,9 +153,9 @@ public:
// The top-level statment which is (or contains) loads/stores.
Instruction *inst;
// List of load operations.
- SmallVector<OperationInst *, 4> loads;
+ SmallVector<Instruction *, 4> loads;
// List of store op insts.
- SmallVector<OperationInst *, 4> stores;
+ SmallVector<Instruction *, 4> stores;
Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}
// Returns the load op count for 'memref'.
@@ -258,16 +258,13 @@ public:
for (auto *storeOpInst : node->stores) {
auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
auto *inst = memref->getDefiningInst();
- auto *opInst = dyn_cast_or_null<OperationInst>(inst);
- // Return false if 'memref' is a function argument.
- if (opInst == nullptr)
+ // Return false if 'memref' is a block argument.
+ if (!inst)
return true;
// Return false if any use of 'memref' escapes the function.
- for (auto &use : memref->getUses()) {
- auto *user = dyn_cast<OperationInst>(use.getOwner());
- if (!user || !isMemRefDereferencingOp(*user))
+ for (auto &use : memref->getUses())
+ if (!isMemRefDereferencingOp(*use.getOwner()))
return true;
- }
}
return false;
}
@@ -461,8 +458,8 @@ public:
}
// Adds ops in 'loads' and 'stores' to node at 'id'.
- void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads,
- const SmallVectorImpl<OperationInst *> &stores) {
+ void addToNode(unsigned id, const SmallVectorImpl<Instruction *> &loads,
+ const SmallVectorImpl<Instruction *> &stores) {
Node *node = getNode(id);
for (auto *loadOpInst : loads)
node->loads.push_back(loadOpInst);
@@ -509,7 +506,7 @@ bool MemRefDependenceGraph::init(Function *f) {
DenseMap<Instruction *, unsigned> forToNodeMap;
for (auto &inst : f->front()) {
- if (auto forOp = cast<OperationInst>(&inst)->dyn_cast<AffineForOp>()) {
+ if (auto forOp = inst.dyn_cast<AffineForOp>()) {
// Create graph node 'id' to represent top-level 'forOp' and record
// all loads and store accesses it contains.
LoopNestStateCollector collector;
@@ -530,30 +527,28 @@ bool MemRefDependenceGraph::init(Function *f) {
}
forToNodeMap[&inst] = node.id;
nodes.insert({node.id, node});
- } else if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
- if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
- // Create graph node for top-level load op.
- Node node(nextNodeId++, &inst);
- node.loads.push_back(opInst);
- auto *memref = opInst->cast<LoadOp>()->getMemRef();
- memrefAccesses[memref].insert(node.id);
- nodes.insert({node.id, node});
- } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
- // Create graph node for top-level store op.
- Node node(nextNodeId++, &inst);
- node.stores.push_back(opInst);
- auto *memref = opInst->cast<StoreOp>()->getMemRef();
- memrefAccesses[memref].insert(node.id);
- nodes.insert({node.id, node});
- } else if (opInst->getNumBlockLists() != 0) {
- // Return false if another region is found (not currently supported).
- return false;
- } else if (opInst->getNumResults() > 0 && !opInst->use_empty()) {
- // Create graph node for top-level producer of SSA values, which
- // could be used by loop nest nodes.
- Node node(nextNodeId++, &inst);
- nodes.insert({node.id, node});
- }
+ } else if (auto loadOp = inst.dyn_cast<LoadOp>()) {
+ // Create graph node for top-level load op.
+ Node node(nextNodeId++, &inst);
+ node.loads.push_back(&inst);
+ auto *memref = inst.cast<LoadOp>()->getMemRef();
+ memrefAccesses[memref].insert(node.id);
+ nodes.insert({node.id, node});
+ } else if (auto storeOp = inst.dyn_cast<StoreOp>()) {
+ // Create graph node for top-level store op.
+ Node node(nextNodeId++, &inst);
+ node.stores.push_back(&inst);
+ auto *memref = inst.cast<StoreOp>()->getMemRef();
+ memrefAccesses[memref].insert(node.id);
+ nodes.insert({node.id, node});
+ } else if (inst.getNumBlockLists() != 0) {
+ // Return false if another region is found (not currently supported).
+ return false;
+ } else if (inst.getNumResults() > 0 && !inst.use_empty()) {
+ // Create graph node for top-level producer of SSA values, which
+ // could be used by loop nest nodes.
+ Node node(nextNodeId++, &inst);
+ nodes.insert({node.id, node});
}
}
@@ -563,12 +558,11 @@ bool MemRefDependenceGraph::init(Function *f) {
const Node &node = idAndNode.second;
if (!node.loads.empty() || !node.stores.empty())
continue;
- auto *opInst = cast<OperationInst>(node.inst);
+ auto *opInst = node.inst;
for (auto *value : opInst->getResults()) {
for (auto &use : value->getUses()) {
- auto *userOpInst = cast<OperationInst>(use.getOwner());
SmallVector<OpPointer<AffineForOp>, 4> loops;
- getLoopIVs(*userOpInst, &loops);
+ getLoopIVs(*use.getOwner(), &loops);
if (loops.empty())
continue;
assert(forToNodeMap.count(loops[0]->getInstruction()) > 0);
@@ -619,7 +613,7 @@ public:
LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}
- void visitInstruction(OperationInst *opInst) {
+ void visitInstruction(Instruction *opInst) {
auto forOp = opInst->dyn_cast<AffineForOp>();
if (!forOp)
return;
@@ -627,8 +621,7 @@ public:
auto *forInst = forOp->getInstruction();
auto *parentInst = forOp->getInstruction()->getParentInst();
if (parentInst != nullptr) {
- assert(cast<OperationInst>(parentInst)->isa<AffineForOp>() &&
- "Expected parent AffineForOp");
+ assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp");
// Add mapping to 'forOp' from its parent AffineForOp.
stats->loopMap[parentInst].push_back(forOp);
}
@@ -637,8 +630,7 @@ public:
unsigned count = 0;
stats->opCountMap[forInst] = 0;
for (auto &inst : *forOp->getBody()) {
- if (!(cast<OperationInst>(inst).isa<AffineForOp>() ||
- cast<OperationInst>(inst).isa<AffineIfOp>()))
+ if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>()))
++count;
}
stats->opCountMap[forInst] = count;
@@ -723,7 +715,7 @@ static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
// was encountered).
// TODO(andydavis) Make this work with non-unit step loops.
static bool buildSliceTripCountMap(
- OperationInst *srcOpInst, ComputationSliceState *sliceState,
+ Instruction *srcOpInst, ComputationSliceState *sliceState,
llvm::SmallDenseMap<Instruction *, uint64_t, 8> *tripCountMap) {
SmallVector<OpPointer<AffineForOp>, 4> srcLoopIVs;
getLoopIVs(*srcOpInst, &srcLoopIVs);
@@ -755,10 +747,10 @@ static bool buildSliceTripCountMap(
// adds them to 'dstLoads'.
static void
moveLoadsAccessingMemrefTo(Value *memref,
- SmallVectorImpl<OperationInst *> *srcLoads,
- SmallVectorImpl<OperationInst *> *dstLoads) {
+ SmallVectorImpl<Instruction *> *srcLoads,
+ SmallVectorImpl<Instruction *> *dstLoads) {
dstLoads->clear();
- SmallVector<OperationInst *, 4> srcLoadsToKeep;
+ SmallVector<Instruction *, 4> srcLoadsToKeep;
for (auto *load : *srcLoads) {
if (load->cast<LoadOp>()->getMemRef() == memref)
dstLoads->push_back(load);
@@ -769,7 +761,7 @@ moveLoadsAccessingMemrefTo(Value *memref,
}
// Returns the innermost common loop depth for the set of operations in 'ops'.
-static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
+static unsigned getInnermostCommonLoopDepth(ArrayRef<Instruction *> ops) {
unsigned numOps = ops.size();
assert(numOps > 0);
@@ -797,10 +789,10 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
// Returns the maximum loop depth at which no dependences between 'loadOpInsts'
// and 'storeOpInsts' are satisfied.
-static unsigned getMaxLoopDepth(ArrayRef<OperationInst *> loadOpInsts,
- ArrayRef<OperationInst *> storeOpInsts) {
+static unsigned getMaxLoopDepth(ArrayRef<Instruction *> loadOpInsts,
+ ArrayRef<Instruction *> storeOpInsts) {
// Merge loads and stores into the same array.
- SmallVector<OperationInst *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
+ SmallVector<Instruction *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
ops.append(storeOpInsts.begin(), storeOpInsts.end());
// Compute the innermost common loop depth for loads and stores.
@@ -913,7 +905,7 @@ unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
// TODO(bondhugula): consider refactoring the common code from generateDma and
// this one.
static Value *createPrivateMemRef(OpPointer<AffineForOp> forOp,
- OperationInst *srcStoreOpInst,
+ Instruction *srcStoreOpInst,
unsigned dstLoopDepth,
Optional<unsigned> fastMemorySpace,
unsigned localBufSizeThreshold) {
@@ -1061,9 +1053,9 @@ static uint64_t getSliceIterationCount(
// *) Compares the total cost of the unfused loop nests to the min cost fused
// loop nest computed in the previous step, and returns true if the latter
// is lower.
-static bool isFusionProfitable(OperationInst *srcOpInst,
- ArrayRef<OperationInst *> dstLoadOpInsts,
- ArrayRef<OperationInst *> dstStoreOpInsts,
+static bool isFusionProfitable(Instruction *srcOpInst,
+ ArrayRef<Instruction *> dstLoadOpInsts,
+ ArrayRef<Instruction *> dstStoreOpInsts,
ComputationSliceState *sliceState,
unsigned *dstLoopDepth) {
LLVM_DEBUG({
@@ -1174,7 +1166,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
computeCostMap[srcLoopIVs[numSrcLoopIVs - 1]->getInstruction()] = -1;
for (auto *loadOp : dstLoadOpInsts) {
auto *parentInst = loadOp->getParentInst();
- if (parentInst && cast<OperationInst>(parentInst)->isa<AffineForOp>())
+ if (parentInst && parentInst->isa<AffineForOp>())
computeCostMap[parentInst] = -1;
}
}
@@ -1393,11 +1385,11 @@ public:
// Get 'dstNode' into which to attempt fusion.
auto *dstNode = mdg->getNode(dstId);
// Skip if 'dstNode' is not a loop nest.
- if (!cast<OperationInst>(dstNode->inst)->isa<AffineForOp>())
+ if (!dstNode->inst->isa<AffineForOp>())
continue;
- SmallVector<OperationInst *, 4> loads = dstNode->loads;
- SmallVector<OperationInst *, 4> dstLoadOpInsts;
+ SmallVector<Instruction *, 4> loads = dstNode->loads;
+ SmallVector<Instruction *, 4> dstLoadOpInsts;
DenseSet<Value *> visitedMemrefs;
while (!loads.empty()) {
// Get memref of load on top of the stack.
@@ -1426,7 +1418,7 @@ public:
// Get 'srcNode' from which to attempt fusion into 'dstNode'.
auto *srcNode = mdg->getNode(srcId);
// Skip if 'srcNode' is not a loop nest.
- if (!cast<OperationInst>(srcNode->inst)->isa<AffineForOp>())
+ if (!srcNode->inst->isa<AffineForOp>())
continue;
// Skip if 'srcNode' has more than one store to any memref.
// TODO(andydavis) Support fusing multi-output src loop nests.
@@ -1454,7 +1446,7 @@ public:
// Get unique 'srcNode' store op.
auto *srcStoreOpInst = srcNode->stores.front();
// Gather 'dstNode' store ops to 'memref'.
- SmallVector<OperationInst *, 2> dstStoreOpInsts;
+ SmallVector<Instruction *, 2> dstStoreOpInsts;
for (auto *storeOpInst : dstNode->stores)
if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
dstStoreOpInsts.push_back(storeOpInst);
@@ -1472,8 +1464,7 @@ public:
srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
if (sliceLoopNest != nullptr) {
// Move 'dstAffineForOp' before 'insertPointInst' if needed.
- auto dstAffineForOp =
- cast<OperationInst>(dstNode->inst)->cast<AffineForOp>();
+ auto dstAffineForOp = dstNode->inst->cast<AffineForOp>();
if (insertPointInst != dstAffineForOp->getInstruction()) {
dstAffineForOp->getInstruction()->moveBefore(insertPointInst);
}
@@ -1488,7 +1479,7 @@ public:
promoteIfSingleIteration(forOp);
}
// Create private memref for 'memref' in 'dstAffineForOp'.
- SmallVector<OperationInst *, 4> storesForMemref;
+ SmallVector<Instruction *, 4> storesForMemref;
for (auto *storeOpInst : sliceCollector.storeOpInsts) {
if (storeOpInst->cast<StoreOp>()->getMemRef() == memref)
storesForMemref.push_back(storeOpInst);
@@ -1541,9 +1532,8 @@ public:
continue;
// Use list expected to match the dep graph info.
auto *inst = memref->getDefiningInst();
- auto *opInst = dyn_cast_or_null<OperationInst>(inst);
- if (opInst && opInst->isa<AllocOp>())
- opInst->erase();
+ if (inst && inst->isa<AllocOp>())
+ inst->erase();
}
}
};
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp
index f1ee7fd1853..8b368e5f182 100644
--- a/mlir/lib/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Transforms/LoopTiling.cpp
@@ -237,14 +237,13 @@ getTileableBands(Function *f,
do {
band.push_back(currInst);
} while (currInst->getBody()->getInstructions().size() == 1 &&
- (currInst = cast<OperationInst>(currInst->getBody()->front())
- .dyn_cast<AffineForOp>()));
+ (currInst = currInst->getBody()->front().dyn_cast<AffineForOp>()));
bands->push_back(band);
};
for (auto &block : *f)
for (auto &inst : block)
- if (auto forOp = cast<OperationInst>(inst).dyn_cast<AffineForOp>())
+ if (auto forOp = inst.dyn_cast<AffineForOp>())
getMaximalPerfectLoopNest(forOp);
}
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index 9c9952d31ca..b1e15ccb07b 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -113,7 +113,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
return hasInnerLoops;
}
- bool walkPostOrder(OperationInst *opInst) {
+ bool walkPostOrder(Instruction *opInst) {
bool hasInnerLoops = false;
for (auto &blockList : opInst->getBlockLists())
for (auto &block : blockList)
@@ -140,7 +140,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
const unsigned minTripCount;
ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
- void visitInstruction(OperationInst *opInst) {
+ void visitInstruction(Instruction *opInst) {
auto forOp = opInst->dyn_cast<AffineForOp>();
if (!forOp)
return;
diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
index d87f9d5dc14..74c54fde047 100644
--- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp
+++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
@@ -100,8 +100,7 @@ PassResult LoopUnrollAndJam::runOnFunction(Function *f) {
// any for Inst.
auto &entryBlock = f->front();
if (!entryBlock.empty())
- if (auto forOp =
- cast<OperationInst>(entryBlock.front()).dyn_cast<AffineForOp>())
+ if (auto forOp = entryBlock.front().dyn_cast<AffineForOp>())
runOnAffineForOp(forOp);
return success();
@@ -149,12 +148,12 @@ bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp,
void walk(InstListType::iterator Start, InstListType::iterator End) {
for (auto it = Start; it != End;) {
auto subBlockStart = it;
- while (it != End && !cast<OperationInst>(it)->isa<AffineForOp>())
+ while (it != End && !it->isa<AffineForOp>())
++it;
if (it != subBlockStart)
subBlocks.push_back({subBlockStart, std::prev(it)});
// Process all for insts that appear next.
- while (it != End && cast<OperationInst>(it)->isa<AffineForOp>())
+ while (it != End && it->isa<AffineForOp>())
walk(&*it++);
}
}
@@ -206,8 +205,7 @@ bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp,
// Insert the cleanup loop right after 'forOp'.
FuncBuilder builder(forInst->getBlock(),
std::next(Block::iterator(forInst)));
- auto cleanupAffineForOp =
- cast<OperationInst>(builder.clone(*forInst))->cast<AffineForOp>();
+ auto cleanupAffineForOp = builder.clone(*forInst)->cast<AffineForOp>();
cleanupAffineForOp->setLowerBoundMap(
getCleanupLoopLowerBound(forOp, unrollJamFactor, &builder));
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp
index 08c8188fada..88ccc90c18b 100644
--- a/mlir/lib/Transforms/LowerAffine.cpp
+++ b/mlir/lib/Transforms/LowerAffine.cpp
@@ -616,23 +616,21 @@ PassResult LowerAffinePass::runOnFunction(Function *function) {
// Collect all the For instructions as well as AffineIfOps and AffineApplyOps.
// We do this as a prepass to avoid invalidating the walker with our rewrite.
function->walk([&](Instruction *inst) {
- auto op = cast<OperationInst>(inst);
- if (op->isa<AffineApplyOp>() || op->isa<AffineForOp>() ||
- op->isa<AffineIfOp>())
+ if (inst->isa<AffineApplyOp>() || inst->isa<AffineForOp>() ||
+ inst->isa<AffineIfOp>())
instsToRewrite.push_back(inst);
});
// Rewrite all of the ifs and fors. We walked the instructions in preorder,
// so we know that we will rewrite them in the same order.
for (auto *inst : instsToRewrite) {
- auto op = cast<OperationInst>(inst);
- if (auto ifOp = op->dyn_cast<AffineIfOp>()) {
+ if (auto ifOp = inst->dyn_cast<AffineIfOp>()) {
if (lowerAffineIf(ifOp))
return failure();
- } else if (auto forOp = op->dyn_cast<AffineForOp>()) {
+ } else if (auto forOp = inst->dyn_cast<AffineForOp>()) {
if (lowerAffineFor(forOp))
return failure();
- } else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
+ } else if (lowerAffineApply(inst->cast<AffineApplyOp>())) {
return failure();
}
}
diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp
index 7f1e9b157d8..63fb45db9c5 100644
--- a/mlir/lib/Transforms/LowerVectorTransfers.cpp
+++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp
@@ -401,13 +401,12 @@ public:
explicit VectorTransferExpander(MLIRContext *context)
: MLLoweringPattern(VectorTransferOpTy::getOperationName(), 1, context) {}
- PatternMatchResult match(OperationInst *op) const override {
+ PatternMatchResult match(Instruction *op) const override {
if (m_Op<VectorTransferOpTy>().match(op))
return matchSuccess();
return matchFailure();
}
- void rewriteOpInst(OperationInst *op,
- MLFuncGlobalLoweringState *funcWiseState,
+ void rewriteOpInst(Instruction *op, MLFuncGlobalLoweringState *funcWiseState,
std::unique_ptr<PatternState> opState,
MLFuncLoweringRewriter *rewriter) const override {
VectorTransferRewriter<VectorTransferOpTy>(
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index f2dae11112b..f55c2154f08 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -246,8 +246,8 @@ static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
return res;
}
-static OperationInst *
-instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType,
+static Instruction *
+instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap);
/// Not all Values belong to a program slice scoped within the immediately
@@ -391,7 +391,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType,
/// - constant splat is replaced by constant splat of `hwVectorType`.
/// TODO(ntv): add more substitutions on a per-need basis.
static SmallVector<NamedAttribute, 1>
-materializeAttributes(OperationInst *opInst, VectorType hwVectorType) {
+materializeAttributes(Instruction *opInst, VectorType hwVectorType) {
SmallVector<NamedAttribute, 1> res;
for (auto a : opInst->getAttrs()) {
if (auto splat = a.second.dyn_cast<SplatElementsAttr>()) {
@@ -411,8 +411,8 @@ materializeAttributes(OperationInst *opInst, VectorType hwVectorType) {
/// substitutionsMap.
///
/// If the underlying substitution fails, this fails too and returns nullptr.
-static OperationInst *
-instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType,
+static Instruction *
+instantiate(FuncBuilder *b, Instruction *opInst, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap) {
assert(!opInst->isa<VectorTransferReadOp>() &&
"Should call the function specialized for VectorTransferReadOp");
@@ -488,7 +488,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer,
/// `hwVectorType` int the covering of the super-vector type. For a more
/// detailed description of the problem, see the description of
/// reindexAffineIndices.
-static OperationInst *
+static Instruction *
instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType,
ArrayRef<unsigned> hwVectorInstance,
DenseMap<const Value *, Value *> *substitutionsMap) {
@@ -512,7 +512,7 @@ instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType,
/// `hwVectorType` int the covering of th3e super-vector type. For a more
/// detailed description of the problem, see the description of
/// reindexAffineIndices.
-static OperationInst *
+static Instruction *
instantiate(FuncBuilder *b, VectorTransferWriteOp *write,
VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
DenseMap<const Value *, Value *> *substitutionsMap) {
@@ -555,21 +555,20 @@ static bool instantiateMaterialization(Instruction *inst,
// Create a builder here for unroll-and-jam effects.
FuncBuilder b(inst);
- auto *opInst = cast<OperationInst>(inst);
// AffineApplyOp are ignored: instantiating the proper vector op will take
// care of AffineApplyOps by composing them properly.
- if (opInst->isa<AffineApplyOp>()) {
+ if (inst->isa<AffineApplyOp>()) {
return false;
}
- if (opInst->getNumBlockLists() != 0)
+ if (inst->getNumBlockLists() != 0)
return inst->emitError("NYI path Op with region");
- if (auto write = opInst->dyn_cast<VectorTransferWriteOp>()) {
+ if (auto write = inst->dyn_cast<VectorTransferWriteOp>()) {
auto *clone = instantiate(&b, write, state->hwVectorType,
state->hwVectorInstance, state->substitutionsMap);
return clone == nullptr;
}
- if (auto read = opInst->dyn_cast<VectorTransferReadOp>()) {
+ if (auto read = inst->dyn_cast<VectorTransferReadOp>()) {
auto *clone = instantiate(&b, read, state->hwVectorType,
state->hwVectorInstance, state->substitutionsMap);
if (!clone) {
@@ -582,19 +581,19 @@ static bool instantiateMaterialization(Instruction *inst,
// The only op with 0 results reaching this point must, by construction, be
// VectorTransferWriteOps and have been caught above. Ops with >= 2 results
// are not yet supported. So just support 1 result.
- if (opInst->getNumResults() != 1) {
+ if (inst->getNumResults() != 1) {
return inst->emitError("NYI: ops with != 1 results");
}
- if (opInst->getResult(0)->getType() != state->superVectorType) {
+ if (inst->getResult(0)->getType() != state->superVectorType) {
return inst->emitError("Op does not return a supervector.");
}
auto *clone =
- instantiate(&b, opInst, state->hwVectorType, state->substitutionsMap);
+ instantiate(&b, inst, state->hwVectorType, state->substitutionsMap);
if (!clone) {
return true;
}
state->substitutionsMap->insert(
- std::make_pair(opInst->getResult(0), clone->getResult(0)));
+ std::make_pair(inst->getResult(0), clone->getResult(0)));
return false;
}
@@ -645,7 +644,7 @@ static bool emitSlice(MaterializationState *state,
}
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
- LLVM_DEBUG(cast<OperationInst>((*slice)[0])->getFunction()->print(dbgs()));
+ LLVM_DEBUG((*slice)[0]->getFunction()->print(dbgs()));
// slice are topologically sorted, we can just erase them in reverse
// order. Reverse iterator does not just work simply with an operator*
@@ -677,7 +676,7 @@ static bool emitSlice(MaterializationState *state,
/// scope.
/// TODO(ntv): please document return value.
static bool materialize(Function *f,
- const SetVector<OperationInst *> &terminators,
+ const SetVector<Instruction *> &terminators,
MaterializationState *state) {
DenseSet<Instruction *> seen;
DominanceInfo domInfo(f);
@@ -757,18 +756,17 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) {
// Capture terminators; i.e. vector_transfer_write ops involving a strict
// super-vector of subVectorType.
auto filter = [subVectorType](const Instruction &inst) {
- const auto &opInst = cast<OperationInst>(inst);
- if (!opInst.isa<VectorTransferWriteOp>()) {
+ if (!inst.isa<VectorTransferWriteOp>()) {
return false;
}
- return matcher::operatesOnSuperVectors(opInst, subVectorType);
+ return matcher::operatesOnSuperVectors(inst, subVectorType);
};
auto pat = Op(filter);
SmallVector<NestedMatch, 8> matches;
pat.match(f, &matches);
- SetVector<OperationInst *> terminators;
+ SetVector<Instruction *> terminators;
for (auto m : matches) {
- terminators.insert(cast<OperationInst>(m.getMatchedInstruction()));
+ terminators.insert(m.getMatchedInstruction());
}
auto fail = materialize(f, terminators, &state);
diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
index b9386c384dd..b2b69dc7b6d 100644
--- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -75,12 +75,12 @@ struct MemRefDataFlowOpt : public FunctionPass, InstWalker<MemRefDataFlowOpt> {
PassResult runOnFunction(Function *f) override;
- void visitInstruction(OperationInst *opInst);
+ void visitInstruction(Instruction *opInst);
// A list of memref's that are potentially dead / could be eliminated.
SmallPtrSet<Value *, 4> memrefsToErase;
// Load op's whose results were replaced by those forwarded from stores.
- std::vector<OperationInst *> loadOpsToErase;
+ std::vector<Instruction *> loadOpsToErase;
DominanceInfo *domInfo = nullptr;
PostDominanceInfo *postDomInfo = nullptr;
@@ -100,22 +100,22 @@ FunctionPass *mlir::createMemRefDataFlowOptPass() {
// This is a straightforward implementation not optimized for speed. Optimize
// this in the future if needed.
-void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) {
- OperationInst *lastWriteStoreOp = nullptr;
+void MemRefDataFlowOpt::visitInstruction(Instruction *opInst) {
+ Instruction *lastWriteStoreOp = nullptr;
auto loadOp = opInst->dyn_cast<LoadOp>();
if (!loadOp)
return;
- OperationInst *loadOpInst = opInst;
+ Instruction *loadOpInst = opInst;
// First pass over the use list to get minimum number of surrounding
// loops common between the load op and the store op, with min taken across
// all store ops.
- SmallVector<OperationInst *, 8> storeOps;
+ SmallVector<Instruction *, 8> storeOps;
unsigned minSurroundingLoops = getNestingDepth(*loadOpInst);
for (InstOperand &use : loadOp->getMemRef()->getUses()) {
- auto storeOp = cast<OperationInst>(use.getOwner())->dyn_cast<StoreOp>();
+ auto storeOp = use.getOwner()->dyn_cast<StoreOp>();
if (!storeOp)
continue;
auto *storeOpInst = storeOp->getInstruction();
@@ -131,11 +131,11 @@ void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) {
// and loadOp.
// The list of store op candidates for forwarding - need to satisfy the
// conditions listed at the top.
- SmallVector<OperationInst *, 8> fwdingCandidates;
+ SmallVector<Instruction *, 8> fwdingCandidates;
// Store ops that have a dependence into the load (even if they aren't
// forwarding candidates). Each forwarding candidate will be checked for a
// post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores.
- SmallVector<OperationInst *, 8> depSrcStores;
+ SmallVector<Instruction *, 8> depSrcStores;
for (auto *storeOpInst : storeOps) {
MemRefAccess srcAccess(storeOpInst);
MemRefAccess destAccess(loadOpInst);
@@ -197,7 +197,7 @@ void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) {
// that postdominates all 'depSrcStores' (if such a store exists) is the
// unique store providing the value to the load, i.e., provably the last
// writer to that memref loc.
- if (llvm::all_of(depSrcStores, [&](OperationInst *depStore) {
+ if (llvm::all_of(depSrcStores, [&](Instruction *depStore) {
return postDomInfo->postDominates(storeOpInst, depStore);
})) {
lastWriteStoreOp = storeOpInst;
@@ -246,24 +246,22 @@ PassResult MemRefDataFlowOpt::runOnFunction(Function *f) {
// to do this as well, but we'll do it here since we collected these anyway.
for (auto *memref : memrefsToErase) {
// If the memref hasn't been alloc'ed in this function, skip.
- OperationInst *defInst = memref->getDefiningInst();
+ Instruction *defInst = memref->getDefiningInst();
if (!defInst || !defInst->isa<AllocOp>())
// TODO(mlir-team): if the memref was returned by a 'call' instruction, we
// could still erase it if the call had no side-effects.
continue;
if (std::any_of(memref->use_begin(), memref->use_end(),
[&](InstOperand &use) {
- auto *ownerInst = cast<OperationInst>(use.getOwner());
+ auto *ownerInst = use.getOwner();
return (!ownerInst->isa<StoreOp>() &&
!ownerInst->isa<DeallocOp>());
}))
continue;
// Erase all stores, the dealloc, and the alloc on the memref.
- for (auto it = memref->use_begin(), e = memref->use_end(); it != e;) {
- auto &use = *(it++);
- cast<OperationInst>(use.getOwner())->erase();
- }
+ for (auto &use : llvm::make_early_inc_range(memref->getUses()))
+ use.getOwner()->erase();
defInst->erase();
}
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index 8d13800160d..ba3be5e95f4 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -61,7 +61,7 @@ FunctionPass *mlir::createPipelineDataTransferPass() {
// Returns the position of the tag memref operand given a DMA instruction.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
-static unsigned getTagMemRefPos(const OperationInst &dmaInst) {
+static unsigned getTagMemRefPos(const Instruction &dmaInst) {
assert(dmaInst.isa<DmaStartOp>() || dmaInst.isa<DmaWaitOp>());
if (dmaInst.isa<DmaStartOp>()) {
// Second to last operand.
@@ -142,7 +142,7 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) {
// deleted and replaced by a prologue, a new steady-state loop and an
// epilogue).
forOps.clear();
- f->walkPostOrder([&](OperationInst *opInst) {
+ f->walkPostOrder([&](Instruction *opInst) {
if (auto forOp = opInst->dyn_cast<AffineForOp>())
forOps.push_back(forOp);
});
@@ -180,33 +180,26 @@ static bool checkTagMatch(OpPointer<DmaStartOp> startOp,
// Identify matching DMA start/finish instructions to overlap computation with.
static void findMatchingStartFinishInsts(
OpPointer<AffineForOp> forOp,
- SmallVectorImpl<std::pair<OperationInst *, OperationInst *>>
- &startWaitPairs) {
+ SmallVectorImpl<std::pair<Instruction *, Instruction *>> &startWaitPairs) {
// Collect outgoing DMA instructions - needed to check for dependences below.
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
for (auto &inst : *forOp->getBody()) {
- auto *opInst = dyn_cast<OperationInst>(&inst);
- if (!opInst)
- continue;
OpPointer<DmaStartOp> dmaStartOp;
- if ((dmaStartOp = opInst->dyn_cast<DmaStartOp>()) &&
+ if ((dmaStartOp = inst.dyn_cast<DmaStartOp>()) &&
dmaStartOp->isSrcMemorySpaceFaster())
outgoingDmaOps.push_back(dmaStartOp);
}
- SmallVector<OperationInst *, 4> dmaStartInsts, dmaFinishInsts;
+ SmallVector<Instruction *, 4> dmaStartInsts, dmaFinishInsts;
for (auto &inst : *forOp->getBody()) {
- auto *opInst = dyn_cast<OperationInst>(&inst);
- if (!opInst)
- continue;
// Collect DMA finish instructions.
- if (opInst->isa<DmaWaitOp>()) {
- dmaFinishInsts.push_back(opInst);
+ if (inst.isa<DmaWaitOp>()) {
+ dmaFinishInsts.push_back(&inst);
continue;
}
OpPointer<DmaStartOp> dmaStartOp;
- if (!(dmaStartOp = opInst->dyn_cast<DmaStartOp>()))
+ if (!(dmaStartOp = inst.dyn_cast<DmaStartOp>()))
continue;
// Only DMAs incoming into higher memory spaces are pipelined for now.
// TODO(bondhugula): handle outgoing DMA pipelining.
@@ -236,7 +229,7 @@ static void findMatchingStartFinishInsts(
}
}
if (!escapingUses)
- dmaStartInsts.push_back(opInst);
+ dmaStartInsts.push_back(&inst);
}
// For each start instruction, we look for a matching finish instruction.
@@ -262,7 +255,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
return success();
}
- SmallVector<std::pair<OperationInst *, OperationInst *>, 4> startWaitPairs;
+ SmallVector<std::pair<Instruction *, Instruction *>, 4> startWaitPairs;
findMatchingStartFinishInsts(forOp, startWaitPairs);
if (startWaitPairs.empty()) {
@@ -335,7 +328,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
} else {
// If a slice wasn't created, the reachable affine_apply op's from its
// operands are the ones that go with it.
- SmallVector<OperationInst *, 4> affineApplyInsts;
+ SmallVector<Instruction *, 4> affineApplyInsts;
SmallVector<Value *, 4> operands(dmaStartInst->getOperands());
getReachableAffineApplyOps(operands, affineApplyInsts);
for (const auto *inst : affineApplyInsts) {
@@ -356,13 +349,13 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
for (auto &inst : *forOp->getBody()) {
assert(instShiftMap.find(&inst) != instShiftMap.end());
shifts[s++] = instShiftMap[&inst];
- LLVM_DEBUG(
- // Tagging instructions with shifts for debugging purposes.
- if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
- FuncBuilder b(opInst);
- opInst->setAttr(b.getIdentifier("shift"),
- b.getI64IntegerAttr(shifts[s - 1]));
- });
+
+ // Tagging instructions with shifts for debugging purposes.
+ LLVM_DEBUG({
+ FuncBuilder b(&inst);
+ inst.setAttr(b.getIdentifier("shift"),
+ b.getI64IntegerAttr(shifts[s - 1]));
+ });
}
if (!isInstwiseShiftValid(forOp, shifts)) {
diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp
index a9fcfc5bd11..29509911e31 100644
--- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp
@@ -64,7 +64,7 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) {
}
PassResult SimplifyAffineStructures::runOnFunction(Function *f) {
- f->walk([&](OperationInst *opInst) {
+ f->walk([&](Instruction *opInst) {
for (auto attr : opInst->getAttrs()) {
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
MutableAffineMap mMap(mapAttr.getValue());
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 790f971bb58..45c57e2f307 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -38,13 +38,13 @@ public:
worklist.reserve(64);
// Add all operations to the worklist.
- fn->walk([&](OperationInst *inst) { addToWorklist(inst); });
+ fn->walk([&](Instruction *inst) { addToWorklist(inst); });
}
/// Perform the rewrites.
void simplifyFunction();
- void addToWorklist(OperationInst *op) {
+ void addToWorklist(Instruction *op) {
// Check to see if the worklist already contains this op.
if (worklistMap.count(op))
return;
@@ -53,7 +53,7 @@ public:
worklist.push_back(op);
}
- OperationInst *popFromWorklist() {
+ Instruction *popFromWorklist() {
auto *op = worklist.back();
worklist.pop_back();
@@ -65,7 +65,7 @@ public:
/// If the specified operation is in the worklist, remove it. If not, this is
/// a no-op.
- void removeFromWorklist(OperationInst *op) {
+ void removeFromWorklist(Instruction *op) {
auto it = worklistMap.find(op);
if (it != worklistMap.end()) {
assert(worklist[it->second] == op && "malformed worklist data structure");
@@ -77,7 +77,7 @@ public:
protected:
// Implement the hook for creating operations, and make sure that newly
// created ops are added to the worklist for processing.
- OperationInst *createOperation(const OperationState &state) override {
+ Instruction *createOperation(const OperationState &state) override {
auto *result = builder.createOperation(state);
addToWorklist(result);
return result;
@@ -85,20 +85,18 @@ protected:
// If an operation is about to be removed, make sure it is not in our
// worklist anymore because we'd get dangling references to it.
- void notifyOperationRemoved(OperationInst *op) override {
+ void notifyOperationRemoved(Instruction *op) override {
removeFromWorklist(op);
}
// When the root of a pattern is about to be replaced, it can trigger
// simplifications to its users - make sure to add them to the worklist
// before the root is changed.
- void notifyRootReplaced(OperationInst *op) override {
+ void notifyRootReplaced(Instruction *op) override {
for (auto *result : op->getResults())
// TODO: Add a result->getUsers() iterator.
- for (auto &user : result->getUses()) {
- if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
- addToWorklist(op);
- }
+ for (auto &user : result->getUses())
+ addToWorklist(user.getOwner());
// TODO: Walk the operand list dropping them as we go. If any of them
// drop to zero uses, then add them to the worklist to allow them to be
@@ -116,13 +114,13 @@ private:
/// need to be revisited, plus their index in the worklist. This allows us to
/// efficiently remove operations from the worklist when they are erased from
/// the function, even if they aren't the root of a pattern.
- std::vector<OperationInst *> worklist;
- DenseMap<OperationInst *, unsigned> worklistMap;
+ std::vector<Instruction *> worklist;
+ DenseMap<Instruction *, unsigned> worklistMap;
/// As part of canonicalization, we move constants to the top of the entry
/// block of the current function and de-duplicate them. This keeps track of
/// constants we have done this for.
- DenseMap<std::pair<Attribute, Type>, OperationInst *> uniquedConstants;
+ DenseMap<std::pair<Attribute, Type>, Instruction *> uniquedConstants;
};
}; // end anonymous namespace
@@ -229,10 +227,8 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
// revisit them.
//
// TODO: Add a result->getUsers() iterator.
- for (auto &operand : op->getResult(i)->getUses()) {
- if (auto *op = dyn_cast<OperationInst>(operand.getOwner()))
- addToWorklist(op);
- }
+ for (auto &operand : op->getResult(i)->getUses())
+ addToWorklist(operand.getOwner());
res->replaceAllUsesWith(cstValue);
}
@@ -267,10 +263,8 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
if (res->use_empty()) // ignore dead uses.
continue;
- for (auto &operand : op->getResult(i)->getUses()) {
- if (auto *op = dyn_cast<OperationInst>(operand.getOwner()))
- addToWorklist(op);
- }
+ for (auto &operand : op->getResult(i)->getUses())
+ addToWorklist(operand.getOwner());
res->replaceAllUsesWith(resultValues[i]);
}
}
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 153557de04a..5bf17989bef 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -101,7 +101,7 @@ bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) {
// Replaces all IV uses to its single iteration value.
auto *iv = forOp->getInductionVar();
- OperationInst *forInst = forOp->getInstruction();
+ Instruction *forInst = forOp->getInstruction();
if (!iv->use_empty()) {
if (forOp->hasConstantLowerBound()) {
auto *mlFunc = forInst->getFunction();
@@ -135,7 +135,7 @@ bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) {
/// their body into the containing Block.
void mlir::promoteSingleIterationLoops(Function *f) {
// Gathers all innermost loops through a post order pruned walk.
- f->walkPostOrder([](OperationInst *inst) {
+ f->walkPostOrder([](Instruction *inst) {
if (auto forOp = inst->dyn_cast<AffineForOp>())
promoteIfSingleIteration(forOp);
});
@@ -394,11 +394,10 @@ bool mlir::loopUnrollByFactor(OpPointer<AffineForOp> forOp,
return false;
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
- OperationInst *forInst = forOp->getInstruction();
+ Instruction *forInst = forOp->getInstruction();
if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst));
- auto cleanupForInst =
- cast<OperationInst>(builder.clone(*forInst))->cast<AffineForOp>();
+ auto cleanupForInst = builder.clone(*forInst)->cast<AffineForOp>();
auto clLbMap = getCleanupLoopLowerBound(forOp, unrollFactor, &builder);
assert(clLbMap &&
"cleanup loop lower bound map for single result bound maps can "
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index 879a4f4b585..524e8d542f5 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -37,7 +37,7 @@ using namespace mlir;
/// Return true if this operation dereferences one or more memref's.
// Temporary utility: will be replaced when this is modeled through
// side-effects/op traits. TODO(b/117228571)
-static bool isMemRefDereferencingOp(const OperationInst &op) {
+static bool isMemRefDereferencingOp(const Instruction &op) {
if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
op.isa<DmaWaitOp>())
return true;
@@ -76,12 +76,11 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
std::make_unique<PostDominanceInfo>(postDomInstFilter->getFunction());
// The ops where memref replacement succeeds are replaced with new ones.
- SmallVector<OperationInst *, 8> opsToErase;
+ SmallVector<Instruction *, 8> opsToErase;
// Walk all uses of old memref. Operation using the memref gets replaced.
- for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
- InstOperand &use = *(it++);
- auto *opInst = cast<OperationInst>(use.getOwner());
+ for (auto &use : llvm::make_early_inc_range(oldMemRef->getUses())) {
+ auto *opInst = use.getOwner();
// Skip this use if it's not dominated by domInstFilter.
if (domInstFilter && !domInfo->dominates(domInstFilter, opInst))
@@ -217,8 +216,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
/// uses besides this opInst; otherwise returns the list of affine_apply
/// operations created in output argument `sliceOps`.
void mlir::createAffineComputationSlice(
- OperationInst *opInst,
- SmallVectorImpl<OpPointer<AffineApplyOp>> *sliceOps) {
+ Instruction *opInst, SmallVectorImpl<OpPointer<AffineApplyOp>> *sliceOps) {
// Collect all operands that are results of affine apply ops.
SmallVector<Value *, 4> subOperands;
subOperands.reserve(opInst->getNumOperands());
@@ -230,7 +228,7 @@ void mlir::createAffineComputationSlice(
}
// Gather sequence of AffineApplyOps reachable from 'subOperands'.
- SmallVector<OperationInst *, 4> affineApplyOps;
+ SmallVector<Instruction *, 4> affineApplyOps;
getReachableAffineApplyOps(subOperands, affineApplyOps);
// Skip transforming if there are no affine maps to compose.
if (affineApplyOps.empty())
@@ -341,8 +339,7 @@ bool mlir::constantFoldBounds(OpPointer<AffineForOp> forInst) {
}
void mlir::remapFunctionAttrs(
- OperationInst &op,
- const DenseMap<Attribute, FunctionAttr> &remappingTable) {
+ Instruction &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
for (auto attr : op.getAttrs()) {
// Do the remapping, if we got the same thing back, then it must contain
// functions that aren't getting remapped.
diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
index a9b9752ef51..7d51637a6e1 100644
--- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
+++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
@@ -110,17 +110,13 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
// Only filter instructions that operate on a strict super-vector and have one
// return. This makes testing easier.
auto filter = [subVectorType](const Instruction &inst) {
- auto *opInst = dyn_cast<OperationInst>(&inst);
- if (!opInst) {
- return false;
- }
assert(subVectorType.getElementType() ==
Type::getF32(subVectorType.getContext()) &&
"Only f32 supported for now");
- if (!matcher::operatesOnSuperVectors(*opInst, subVectorType)) {
+ if (!matcher::operatesOnSuperVectors(inst, subVectorType)) {
return false;
}
- if (opInst->getNumResults() != 1) {
+ if (inst.getNumResults() != 1) {
return false;
}
return true;
@@ -129,7 +125,7 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
SmallVector<NestedMatch, 8> matches;
pat.match(f, &matches);
for (auto m : matches) {
- auto *opInst = cast<OperationInst>(m.getMatchedInstruction());
+ auto *opInst = m.getMatchedInstruction();
// This is a unit test that only checks and prints shape ratio.
// As a consequence we write only Ops with a single return type for the
// purpose of this test. If we need to test more intricate behavior in the
@@ -159,8 +155,7 @@ static NestedPattern patternTestSlicingOps() {
using matcher::Op;
// Match all OpInstructions with the kTestSlicingOpName name.
auto filter = [](const Instruction &inst) {
- const auto &opInst = cast<OperationInst>(inst);
- return opInst.getName().getStringRef() == kTestSlicingOpName;
+ return inst.getName().getStringRef() == kTestSlicingOpName;
};
return Op(filter);
}
@@ -209,8 +204,7 @@ void VectorizerTestPass::testSlicing(Function *f) {
}
static bool customOpWithAffineMapAttribute(const Instruction &inst) {
- const auto &opInst = cast<OperationInst>(inst);
- return opInst.getName().getStringRef() ==
+ return inst.getName().getStringRef() ==
VectorizerTestPass::kTestAffineMapOpName;
}
@@ -222,7 +216,7 @@ void VectorizerTestPass::testComposeMaps(Function *f) {
SmallVector<AffineMap, 4> maps;
maps.reserve(matches.size());
for (auto m : llvm::reverse(matches)) {
- auto *opInst = cast<OperationInst>(m.getMatchedInstruction());
+ auto *opInst = m.getMatchedInstruction();
auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
.cast<AffineMapAttr>()
.getValue();
@@ -236,13 +230,11 @@ void VectorizerTestPass::testComposeMaps(Function *f) {
}
static bool affineApplyOp(const Instruction &inst) {
- const auto &opInst = cast<OperationInst>(inst);
- return opInst.isa<AffineApplyOp>();
+ return inst.isa<AffineApplyOp>();
}
static bool singleResultAffineApplyOpWithoutUses(const Instruction &inst) {
- const auto &opInst = cast<OperationInst>(inst);
- auto app = opInst.dyn_cast<AffineApplyOp>();
+ auto app = inst.dyn_cast<AffineApplyOp>();
return app && app->use_empty();
}
@@ -259,8 +251,7 @@ void VectorizerTestPass::testNormalizeMaps(Function *f) {
SmallVector<NestedMatch, 8> matches;
pattern.match(f, &matches);
for (auto m : matches) {
- auto app =
- cast<OperationInst>(m.getMatchedInstruction())->cast<AffineApplyOp>();
+ auto app = m.getMatchedInstruction()->cast<AffineApplyOp>();
FuncBuilder b(m.getMatchedInstruction());
SmallVector<Value *, 8> operands(app->getOperands());
makeComposedAffineApply(&b, app->getLoc(), app->getAffineMap(), operands);
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index 661861dcfd4..5a8d5d24661 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -723,22 +723,22 @@ namespace {
struct VectorizationState {
/// Adds an entry of pre/post vectorization instructions in the state.
- void registerReplacement(OperationInst *key, OperationInst *value);
+ void registerReplacement(Instruction *key, Instruction *value);
/// When the current vectorization pattern is successful, this erases the
/// instructions that were marked for erasure in the proper order and resets
/// the internal state for the next pattern.
void finishVectorizationPattern();
- // In-order tracking of original OperationInst that have been vectorized.
+ // In-order tracking of original Instruction that have been vectorized.
// Erase in reverse order.
- SmallVector<OperationInst *, 16> toErase;
- // Set of OperationInst that have been vectorized (the values in the
+ SmallVector<Instruction *, 16> toErase;
+ // Set of Instruction that have been vectorized (the values in the
// vectorizationMap for hashed access). The vectorizedSet is used in
// particular to filter the instructions that have already been vectorized by
// this pattern, when iterating over nested loops in this pattern.
- DenseSet<OperationInst *> vectorizedSet;
- // Map of old scalar OperationInst to new vectorized OperationInst.
- DenseMap<OperationInst *, OperationInst *> vectorizationMap;
+ DenseSet<Instruction *> vectorizedSet;
+ // Map of old scalar Instruction to new vectorized Instruction.
+ DenseMap<Instruction *, Instruction *> vectorizationMap;
// Map of old scalar Value to new vectorized Value.
DenseMap<const Value *, Value *> replacementMap;
// The strategy drives which loop to vectorize by which amount.
@@ -747,17 +747,17 @@ struct VectorizationState {
// vectorizeOperations function. They consist of the subset of load operations
// that have been vectorized. They can be retrieved from `vectorizationMap`
// but it is convenient to keep track of them in a separate data structure.
- DenseSet<OperationInst *> roots;
+ DenseSet<Instruction *> roots;
// Terminator instructions for the worklist in the vectorizeOperations
// function. They consist of the subset of store operations that have been
// vectorized. They can be retrieved from `vectorizationMap` but it is
// convenient to keep track of them in a separate data structure. Since they
// do not necessarily belong to use-def chains starting from loads (e.g
// storing a constant), we need to handle them in a post-pass.
- DenseSet<OperationInst *> terminators;
+ DenseSet<Instruction *> terminators;
// Checks that the type of `inst` is StoreOp and adds it to the terminators
// set.
- void registerTerminator(OperationInst *inst);
+ void registerTerminator(Instruction *inst);
private:
void registerReplacement(const Value *key, Value *value);
@@ -765,8 +765,8 @@ private:
} // end namespace
-void VectorizationState::registerReplacement(OperationInst *key,
- OperationInst *value) {
+void VectorizationState::registerReplacement(Instruction *key,
+ Instruction *value) {
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op: ");
LLVM_DEBUG(key->print(dbgs()));
LLVM_DEBUG(dbgs() << " into ");
@@ -785,7 +785,7 @@ void VectorizationState::registerReplacement(OperationInst *key,
}
}
-void VectorizationState::registerTerminator(OperationInst *inst) {
+void VectorizationState::registerTerminator(Instruction *inst) {
assert(inst->isa<StoreOp>() && "terminator must be a StoreOp");
assert(terminators.count(inst) == 0 &&
"terminator was already inserted previously");
@@ -867,17 +867,16 @@ static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step,
if (!matcher::isLoadOrStore(inst)) {
return false;
}
- auto *opInst = cast<OperationInst>(&inst);
- return state->vectorizationMap.count(opInst) == 0 &&
- state->vectorizedSet.count(opInst) == 0 &&
- state->roots.count(opInst) == 0 &&
- state->terminators.count(opInst) == 0;
+ return state->vectorizationMap.count(&inst) == 0 &&
+ state->vectorizedSet.count(&inst) == 0 &&
+ state->roots.count(&inst) == 0 &&
+ state->terminators.count(&inst) == 0;
};
auto loadAndStores = matcher::Op(notVectorizedThisPattern);
SmallVector<NestedMatch, 8> loadAndStoresMatches;
loadAndStores.match(loop->getInstruction(), &loadAndStoresMatches);
for (auto ls : loadAndStoresMatches) {
- auto *opInst = cast<OperationInst>(ls.getMatchedInstruction());
+ auto *opInst = ls.getMatchedInstruction();
auto load = opInst->dyn_cast<LoadOp>();
auto store = opInst->dyn_cast<StoreOp>();
LLVM_DEBUG(opInst->print(dbgs()));
@@ -900,7 +899,7 @@ static bool vectorizeAffineForOp(AffineForOp *loop, int64_t step,
static FilterFunctionType
isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) {
return [fastestVaryingMemRefDimension](const Instruction &forInst) {
- auto loop = cast<OperationInst>(forInst).cast<AffineForOp>();
+ auto loop = forInst.cast<AffineForOp>();
return isVectorizableLoopAlongFastestVaryingMemRefDim(
loop, fastestVaryingMemRefDimension);
};
@@ -915,7 +914,7 @@ static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches,
/// recursively in DFS post-order.
static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) {
auto *loopInst = oneMatch.getMatchedInstruction();
- auto loop = cast<OperationInst>(loopInst)->cast<AffineForOp>();
+ auto loop = loopInst->cast<AffineForOp>();
auto childrenMatches = oneMatch.getMatchedChildren();
// 1. DFS postorder recursion, if any of my children fails, I fail too.
@@ -977,15 +976,14 @@ static Value *vectorizeConstant(Instruction *inst, const ConstantOp &constant,
Location loc = inst->getLoc();
auto vectorType = type.cast<VectorType>();
auto attr = SplatElementsAttr::get(vectorType, constant.getValue());
- auto *constantOpInst = cast<OperationInst>(constant.getInstruction());
+ auto *constantOpInst = constant.getInstruction();
OperationState state(
b.getContext(), loc, constantOpInst->getName().getStringRef(), {},
{vectorType},
{make_pair(Identifier::get("value", b.getContext()), attr)});
- auto *splat = cast<OperationInst>(b.createOperation(state));
- return splat->getResult(0);
+ return b.createOperation(state)->getResult(0);
}
/// Returns a uniqu'ed VectorType.
@@ -997,8 +995,7 @@ static Type getVectorType(Value *v, const VectorizationState &state) {
if (!VectorType::isValidElementType(v->getType())) {
return Type();
}
- auto *definingOpInst = cast<OperationInst>(v->getDefiningInst());
- if (state.vectorizedSet.count(definingOpInst) > 0) {
+ if (state.vectorizedSet.count(v->getDefiningInst()) > 0) {
return v->getType().cast<VectorType>();
}
return VectorType::get(state.strategy->vectorSizes, v->getType());
@@ -1029,9 +1026,8 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst,
VectorizationState *state) {
LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: ");
LLVM_DEBUG(operand->print(dbgs()));
- auto *definingInstruction = cast<OperationInst>(operand->getDefiningInst());
// 1. If this value has already been vectorized this round, we are done.
- if (state->vectorizedSet.count(definingInstruction) > 0) {
+ if (state->vectorizedSet.count(operand->getDefiningInst()) > 0) {
LLVM_DEBUG(dbgs() << " -> already vector operand");
return operand;
}
@@ -1062,7 +1058,7 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst,
return nullptr;
};
-/// Encodes OperationInst-specific behavior for vectorization. In general we
+/// Encodes Instruction-specific behavior for vectorization. In general we
/// assume that all operands of an op must be vectorized but this is not always
/// true. In the future, it would be nice to have a trait that describes how a
/// particular operation vectorizes. For now we implement the case distinction
@@ -1071,9 +1067,8 @@ static Value *vectorizeOperand(Value *operand, Instruction *inst,
/// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized.
/// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
/// do one-off logic here; ideally it would be TableGen'd.
-static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
- OperationInst *opInst,
- VectorizationState *state) {
+static Instruction *vectorizeOneInstruction(FuncBuilder *b, Instruction *opInst,
+ VectorizationState *state) {
// Sanity checks.
assert(!opInst->isa<LoadOp>() &&
"all loads must have already been fully vectorized independently");
@@ -1094,7 +1089,7 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
LLVM_DEBUG(permutationMap.print(dbgs()));
auto transfer = b.create<VectorTransferWriteOp>(
opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
- auto *res = cast<OperationInst>(transfer->getInstruction());
+ auto *res = transfer->getInstruction();
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
// "Terminators" (i.e. StoreOps) are erased on the spot.
opInst->erase();
@@ -1119,8 +1114,8 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
// Create a clone of the op with the proper operands and return types.
// TODO(ntv): The following assumes there is always an op with a fixed
// name that works both in scalar mode and vector mode.
- // TODO(ntv): Is it worth considering an OperationInst.clone operation
- // which changes the type so we can promote an OperationInst with less
+ // TODO(ntv): Is it worth considering an Instruction.clone operation
+ // which changes the type so we can promote an Instruction with less
// boilerplate?
OperationState newOp(b->getContext(), opInst->getLoc(),
opInst->getName().getStringRef(), operands, types,
@@ -1129,22 +1124,22 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
return b->createOperation(newOp);
}
-/// Iterates over the OperationInst in the loop and rewrites them using their
+/// Iterates over the Instruction in the loop and rewrites them using their
/// vectorized counterpart by:
-/// 1. iteratively building a worklist of uses of the OperationInst vectorized
+/// 1. iteratively building a worklist of uses of the Instruction vectorized
/// so far by this pattern;
-/// 2. for each OperationInst in the worklist, create the vector form of this
+/// 2. for each Instruction in the worklist, create the vector form of this
/// operation and replace all its uses by the vectorized form. For this step,
/// the worklist must be traversed in order;
/// 3. verify that all operands of the newly vectorized operation have been
/// vectorized by this pattern.
static bool vectorizeOperations(VectorizationState *state) {
// 1. create initial worklist with the uses of the roots.
- SetVector<OperationInst *> worklist;
- auto insertUsesOf = [&worklist, state](OperationInst *vectorized) {
+ SetVector<Instruction *> worklist;
+ auto insertUsesOf = [&worklist, state](Instruction *vectorized) {
for (auto *r : vectorized->getResults())
for (auto &u : r->getUses()) {
- auto *inst = cast<OperationInst>(u.getOwner());
+ auto *inst = u.getOwner();
// Don't propagate to terminals, a separate pass is needed for those.
// TODO(ntv)[b/119759136]: use isa<> once Op is implemented.
if (state->terminators.count(inst) > 0) {
@@ -1166,7 +1161,7 @@ static bool vectorizeOperations(VectorizationState *state) {
// 2. Create vectorized form of the instruction.
// Insert it just before inst, on success register inst as replaced.
FuncBuilder b(inst);
- auto *vectorizedInst = vectorizeOneOperationInst(&b, inst, state);
+ auto *vectorizedInst = vectorizeOneInstruction(&b, inst, state);
if (!vectorizedInst) {
return true;
}
@@ -1179,7 +1174,7 @@ static bool vectorizeOperations(VectorizationState *state) {
// 4. Augment the worklist with uses of the instruction we just vectorized.
// This preserves the proper order in the worklist.
- apply(insertUsesOf, ArrayRef<OperationInst *>{inst});
+ apply(insertUsesOf, ArrayRef<Instruction *>{inst});
}
return false;
}
@@ -1189,8 +1184,7 @@ static bool vectorizeOperations(VectorizationState *state) {
/// Each root may succeed independently but will otherwise clean after itself if
/// anything below it fails.
static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
- auto loop =
- cast<OperationInst>(m.getMatchedInstruction())->cast<AffineForOp>();
+ auto loop = m.getMatchedInstruction()->cast<AffineForOp>();
VectorizationState state;
state.strategy = strategy;
@@ -1207,8 +1201,7 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
}
auto *loopInst = loop->getInstruction();
FuncBuilder builder(loopInst);
- auto clonedLoop =
- cast<OperationInst>(builder.clone(*loopInst))->cast<AffineForOp>();
+ auto clonedLoop = builder.clone(*loopInst)->cast<AffineForOp>();
auto fail = doVectorize(m, &state);
/// Sets up error handling for this root loop. This is how the root match
@@ -1248,12 +1241,12 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
}
// Finally, vectorize the terminators. If anything fails to vectorize, skip.
- auto vectorizeOrFail = [&fail, &state](OperationInst *inst) {
+ auto vectorizeOrFail = [&fail, &state](Instruction *inst) {
if (fail) {
return;
}
FuncBuilder b(inst);
- auto *res = vectorizeOneOperationInst(&b, inst, &state);
+ auto *res = vectorizeOneInstruction(&b, inst, &state);
if (res == nullptr) {
fail = true;
}
OpenPOWER on IntegriCloud