Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/Analysis/AffineAnalysis.cpp  108
-rw-r--r--  mlir/lib/Analysis/AffineStructures.cpp  26
-rw-r--r--  mlir/lib/Analysis/Dominance.cpp  2
-rw-r--r--  mlir/lib/Analysis/LoopAnalysis.cpp  101
-rw-r--r--  mlir/lib/Analysis/MLFunctionMatcher.cpp  80
-rw-r--r--  mlir/lib/Analysis/MemRefBoundCheck.cpp  12
-rw-r--r--  mlir/lib/Analysis/MemRefDependenceCheck.cpp  44
-rw-r--r--  mlir/lib/Analysis/OpStats.cpp  12
-rw-r--r--  mlir/lib/Analysis/SliceAnalysis.cpp  114
-rw-r--r--  mlir/lib/Analysis/Utils.cpp  152
-rw-r--r--  mlir/lib/Analysis/VectorAnalysis.cpp  55
-rw-r--r--  mlir/lib/Analysis/Verifier.cpp  53
-rw-r--r--  mlir/lib/IR/AsmPrinter.cpp  136
-rw-r--r--  mlir/lib/IR/Block.cpp  28
-rw-r--r--  mlir/lib/IR/Builders.cpp  22
-rw-r--r--  mlir/lib/IR/Function.cpp  14
-rw-r--r--  mlir/lib/IR/Instruction.cpp (renamed from mlir/lib/IR/Statement.cpp)  325
-rw-r--r--  mlir/lib/IR/Operation.cpp  4
-rw-r--r--  mlir/lib/IR/PatternMatch.cpp  2
-rw-r--r--  mlir/lib/IR/Value.cpp  14
-rw-r--r--  mlir/lib/Parser/Parser.cpp  90
-rw-r--r--  mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp  2
-rw-r--r--  mlir/lib/Transforms/CSE.cpp  14
-rw-r--r--  mlir/lib/Transforms/ComposeAffineMaps.cpp  24
-rw-r--r--  mlir/lib/Transforms/ConstantFold.cpp  32
-rw-r--r--  mlir/lib/Transforms/ConvertToCFG.cpp  122
-rw-r--r--  mlir/lib/Transforms/DmaGeneration.cpp  74
-rw-r--r--  mlir/lib/Transforms/LoopFusion.cpp  180
-rw-r--r--  mlir/lib/Transforms/LoopTiling.cpp  58
-rw-r--r--  mlir/lib/Transforms/LoopUnroll.cpp  68
-rw-r--r--  mlir/lib/Transforms/LoopUnrollAndJam.cpp  100
-rw-r--r--  mlir/lib/Transforms/LowerVectorTransfers.cpp  12
-rw-r--r--  mlir/lib/Transforms/MaterializeVectors.cpp  104
-rw-r--r--  mlir/lib/Transforms/PipelineDataTransfer.cpp  200
-rw-r--r--  mlir/lib/Transforms/SimplifyAffineExpr.cpp  22
-rw-r--r--  mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp  4
-rw-r--r--  mlir/lib/Transforms/Utils/LoopUtils.cpp  288
-rw-r--r--  mlir/lib/Transforms/Utils/Utils.cpp  152
-rw-r--r--  mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp  48
-rw-r--r--  mlir/lib/Transforms/Vectorize.cpp  207
40 files changed, 1556 insertions(+), 1549 deletions(-)
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index f01735f26e1..8058af06b55 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -25,7 +25,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/MathExtras.h"
@@ -498,22 +498,22 @@ void mlir::getReachableAffineApplyOps(
while (!worklist.empty()) {
State &state = worklist.back();
- auto *opStmt = state.value->getDefiningInst();
+ auto *opInst = state.value->getDefiningInst();
// Note: getDefiningInst will return nullptr if the operand is not an
- // OperationInst (i.e. ForStmt), which is a terminator for the search.
- if (opStmt == nullptr || !opStmt->isa<AffineApplyOp>()) {
+ // OperationInst (i.e. ForInst), which is a terminator for the search.
+ if (opInst == nullptr || !opInst->isa<AffineApplyOp>()) {
worklist.pop_back();
continue;
}
- if (auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>()) {
+ if (auto affineApplyOp = opInst->dyn_cast<AffineApplyOp>()) {
if (state.operandIndex == 0) {
- // Pre-Visit: Add 'opStmt' to reachable sequence.
- affineApplyOps.push_back(opStmt);
+ // Pre-Visit: Add 'opInst' to reachable sequence.
+ affineApplyOps.push_back(opInst);
}
- if (state.operandIndex < opStmt->getNumOperands()) {
+ if (state.operandIndex < opInst->getNumOperands()) {
// Visit: Add next 'affineApplyOp' operand to worklist.
// Get next operand to visit at 'operandIndex'.
- auto *nextOperand = opStmt->getOperand(state.operandIndex);
+ auto *nextOperand = opInst->getOperand(state.operandIndex);
// Increment 'operandIndex' in 'state'.
++state.operandIndex;
// Add 'nextOperand' to worklist.
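The hunk above is a pure rename, but the traversal it preserves is worth spelling out: getReachableAffineApplyOps runs an iterative pre-order DFS whose worklist frames carry a (value, operandIndex) pair, so each affine_apply is recorded on first visit and its operands are expanded one at a time. A minimal standalone sketch of that pattern, with a hypothetical Node type standing in for the real use-def graph (illustration only, not the MLIR API):

#include <vector>

// Hypothetical stand-in for a def in a use-def graph; not the real
// Instruction API.
struct Node {
  std::vector<Node *> operands;  // definitions feeding this node
  bool isAffineApply = false;    // stand-in for isa<AffineApplyOp>
};

// Collects transitively reachable "affine apply" nodes in pre-order,
// mirroring the (value, operandIndex) worklist above.
void getReachable(Node *root, std::vector<Node *> &out) {
  struct State { Node *node; unsigned operandIndex; };
  std::vector<State> worklist{{root, 0}};
  while (!worklist.empty()) {
    State &state = worklist.back();
    Node *n = state.node;
    if (!n || !n->isAffineApply) {  // non-apply defs end the search
      worklist.pop_back();
      continue;
    }
    if (state.operandIndex == 0)
      out.push_back(n);  // pre-visit: record on first touch
    if (state.operandIndex < n->operands.size()) {
      Node *next = n->operands[state.operandIndex++];
      worklist.push_back({next, 0});  // descend into the next operand
    } else {
      worklist.pop_back();  // all operands visited
    }
  }
}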
@@ -533,47 +533,47 @@ void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) {
SmallVector<OperationInst *, 4> affineApplyOps;
getReachableAffineApplyOps(valueMap->getOperands(), affineApplyOps);
// Compose AffineApplyOps in 'affineApplyOps'.
- for (auto *opStmt : affineApplyOps) {
- assert(opStmt->isa<AffineApplyOp>());
- auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>();
+ for (auto *opInst : affineApplyOps) {
+ assert(opInst->isa<AffineApplyOp>());
+ auto affineApplyOp = opInst->dyn_cast<AffineApplyOp>();
// Forward substitute 'affineApplyOp' into 'valueMap'.
valueMap->forwardSubstitute(*affineApplyOp);
}
}
// Builds a system of constraints with dimensional identifiers corresponding to
-// the loop IVs of the forStmts appearing in that order. Any symbols founds in
+// the loop IVs of the forInsts appearing in that order. Any symbols found in
// the bound operands are added as symbols in the system. Returns false for the
// yet unimplemented cases.
// TODO(andydavis,bondhugula) Handle non-unit steps through local variables or
// stride information in FlatAffineConstraints. (E.g., by using iv - lb %
// step = 0 and/or by introducing a method in FlatAffineConstraints
// setExprStride(ArrayRef<int64_t> expr, int64_t stride)
-bool mlir::getIndexSet(ArrayRef<ForStmt *> forStmts,
+bool mlir::getIndexSet(ArrayRef<ForInst *> forInsts,
FlatAffineConstraints *domain) {
- SmallVector<Value *, 4> indices(forStmts.begin(), forStmts.end());
+ SmallVector<Value *, 4> indices(forInsts.begin(), forInsts.end());
// Reset the domain, associating the Values in 'indices' with its dimensions.
- domain->reset(forStmts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
- for (auto *forStmt : forStmts) {
- // Add constraints from forStmt's bounds.
- if (!domain->addForStmtDomain(*forStmt))
+ domain->reset(forInsts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
+ for (auto *forInst : forInsts) {
+ // Add constraints from forInst's bounds.
+ if (!domain->addForInstDomain(*forInst))
return false;
}
return true;
}
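getIndexSet above reserves one dimension per loop IV and delegates the actual constraint rows to addForInstDomain. In the simplest case, constant bounds and unit step, each loop lb <= iv < ub contributes the two inequalities iv - lb >= 0 and ub - 1 - iv >= 0. A toy sketch under those assumptions, with a hypothetical dense inequality representation rather than FlatAffineConstraints:

#include <cstdint>
#include <vector>

// Toy constraint system: each row r encodes
//   sum_j r[j] * iv_j + r.back() >= 0
struct IndexSet {
  unsigned numDims;
  std::vector<std::vector<int64_t>> inequalities;
};

// Hypothetical constant-bound loop: lb <= iv < ub, unit step.
struct Loop { int64_t lb, ub; };

// Builds the iteration domain for a perfect nest, mirroring
// getIndexSet + addForInstDomain for the constant-bounds case.
IndexSet getIndexSet(const std::vector<Loop> &loops) {
  IndexSet set{static_cast<unsigned>(loops.size()), {}};
  for (unsigned d = 0; d < set.numDims; ++d) {
    std::vector<int64_t> lower(set.numDims + 1, 0);
    lower[d] = 1;                   //  iv_d - lb >= 0
    lower.back() = -loops[d].lb;
    std::vector<int64_t> upper(set.numDims + 1, 0);
    upper[d] = -1;                  // -iv_d + (ub - 1) >= 0
    upper.back() = loops[d].ub - 1;
    set.inequalities.push_back(lower);
    set.inequalities.push_back(upper);
  }
  return set;
}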
-// Computes the iteration domain for 'opStmt' and populates 'indexSet', which
-// encapsulates the constraints involving loops surrounding 'opStmt' and
+// Computes the iteration domain for 'opInst' and populates 'indexSet', which
+// encapsulates the constraints involving loops surrounding 'opInst' and
// potentially involving any Function symbols. The dimensional identifiers in
-// 'indexSet' correspond to the loops surounding 'stmt' from outermost to
+// 'indexSet' correspond to the loops surrounding 'inst' from outermost to
// innermost.
-// TODO(andydavis) Add support to handle IfStmts surrounding 'stmt'.
-static bool getStmtIndexSet(const Statement *stmt,
+// TODO(andydavis) Add support to handle IfInsts surrounding 'inst'.
+static bool getInstIndexSet(const Instruction *inst,
FlatAffineConstraints *indexSet) {
- // TODO(andydavis) Extend this to gather enclosing IfStmts and consider
+ // TODO(andydavis) Extend this to gather enclosing IfInsts and consider
// factoring it out into a utility function.
- SmallVector<ForStmt *, 4> loops;
- getLoopIVs(*stmt, &loops);
+ SmallVector<ForInst *, 4> loops;
+ getLoopIVs(*inst, &loops);
return getIndexSet(loops, indexSet);
}
@@ -672,7 +672,7 @@ static void buildDimAndSymbolPositionMaps(
auto updateValuePosMap = [&](ArrayRef<Value *> values, bool isSrc) {
for (unsigned i = 0, e = values.size(); i < e; ++i) {
auto *value = values[i];
- if (!isa<ForStmt>(values[i]))
+ if (!isa<ForInst>(values[i]))
valuePosMap->addSymbolValue(value);
else if (isSrc)
valuePosMap->addSrcValue(value);
@@ -840,13 +840,13 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
// Add equality constraints for any operands that are defined by constant ops.
auto addEqForConstOperands = [&](ArrayRef<const Value *> operands) {
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
- if (isa<ForStmt>(operands[i]))
+ if (isa<ForInst>(operands[i]))
continue;
auto *symbol = operands[i];
assert(symbol->isValidSymbol());
// Check if the symbol is a constant.
- if (auto *opStmt = symbol->getDefiningInst()) {
- if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
+ if (auto *opInst = symbol->getDefiningInst()) {
+ if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol),
constOp->getValue());
}
@@ -909,8 +909,8 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain,
std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds());
unsigned numCommonLoops = 0;
for (unsigned i = 0; i < minNumLoops; ++i) {
- if (!isa<ForStmt>(srcDomain.getIdValue(i)) ||
- !isa<ForStmt>(dstDomain.getIdValue(i)) ||
+ if (!isa<ForInst>(srcDomain.getIdValue(i)) ||
+ !isa<ForInst>(dstDomain.getIdValue(i)) ||
srcDomain.getIdValue(i) != dstDomain.getIdValue(i))
break;
++numCommonLoops;
@@ -918,26 +918,26 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain,
return numCommonLoops;
}
-// Returns Block common to 'srcAccess.opStmt' and 'dstAccess.opStmt'.
+// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
static Block *getCommonBlock(const MemRefAccess &srcAccess,
const MemRefAccess &dstAccess,
const FlatAffineConstraints &srcDomain,
unsigned numCommonLoops) {
if (numCommonLoops == 0) {
- auto *block = srcAccess.opStmt->getBlock();
+ auto *block = srcAccess.opInst->getBlock();
while (block->getContainingInst()) {
block = block->getContainingInst()->getBlock();
}
return block;
}
auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
- assert(isa<ForStmt>(commonForValue));
- return cast<ForStmt>(commonForValue)->getBody();
+ assert(isa<ForInst>(commonForValue));
+ return cast<ForInst>(commonForValue)->getBody();
}
-// Returns true if the ancestor operation statement of 'srcAccess' properly
-// dominates the ancestor operation statement of 'dstAccess' in the same
-// statement block. Returns false otherwise.
+// Returns true if the ancestor operation instruction of 'srcAccess' properly
+// dominates the ancestor operation instruction of 'dstAccess' in the same
+// instruction block. Returns false otherwise.
// Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals,
// the function is named 'srcMayExecuteBeforeDst'.
// Note that 'numCommonLoops' is the number of contiguous surrounding outer
@@ -946,16 +946,16 @@ static bool srcMayExecuteBeforeDst(const MemRefAccess &srcAccess,
const MemRefAccess &dstAccess,
const FlatAffineConstraints &srcDomain,
unsigned numCommonLoops) {
- // Get Block common to 'srcAccess.opStmt' and 'dstAccess.opStmt'.
+ // Get Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
auto *commonBlock =
getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops);
// Check the dominance relationship between the respective ancestors of the
// src and dst in the Block of the innermost among the common loops.
- auto *srcStmt = commonBlock->findAncestorInstInBlock(*srcAccess.opStmt);
- assert(srcStmt != nullptr);
- auto *dstStmt = commonBlock->findAncestorInstInBlock(*dstAccess.opStmt);
- assert(dstStmt != nullptr);
- return mlir::properlyDominates(*srcStmt, *dstStmt);
+ auto *srcInst = commonBlock->findAncestorInstInBlock(*srcAccess.opInst);
+ assert(srcInst != nullptr);
+ auto *dstInst = commonBlock->findAncestorInstInBlock(*dstAccess.opInst);
+ assert(dstInst != nullptr);
+ return mlir::properlyDominates(*srcInst, *dstInst);
}
// Adds ordering constraints to 'dependenceDomain' based on number of loops
@@ -1119,7 +1119,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
// until operands of the AffineValueMap are loop IVs or symbols.
// *) Build iteration domain constraints for each access. Iteration domain
// constraints are pairs of inequality constraints representing the
-// upper/lower loop bounds for each ForStmt in the loop nest associated
+// upper/lower loop bounds for each ForInst in the loop nest associated
// with each access.
// *) Build dimension and symbol position maps for each access, which map
// Values from access functions and iteration domains to their position
@@ -1197,7 +1197,7 @@ bool mlir::checkMemrefAccessDependence(
if (srcAccess.memref != dstAccess.memref)
return false;
// Return 'false' if one of these accesses is not a StoreOp.
- if (!srcAccess.opStmt->isa<StoreOp>() && !dstAccess.opStmt->isa<StoreOp>())
+ if (!srcAccess.opInst->isa<StoreOp>() && !dstAccess.opInst->isa<StoreOp>())
return false;
// Get composed access function for 'srcAccess'.
@@ -1208,19 +1208,19 @@ bool mlir::checkMemrefAccessDependence(
AffineValueMap dstAccessMap;
dstAccess.getAccessMap(&dstAccessMap);
- // Get iteration domain for the 'srcAccess' statement.
+ // Get iteration domain for the 'srcAccess' instruction.
FlatAffineConstraints srcDomain;
- if (!getStmtIndexSet(srcAccess.opStmt, &srcDomain))
+ if (!getInstIndexSet(srcAccess.opInst, &srcDomain))
return false;
- // Get iteration domain for 'dstAccess' statement.
+ // Get iteration domain for 'dstAccess' instruction.
FlatAffineConstraints dstDomain;
- if (!getStmtIndexSet(dstAccess.opStmt, &dstDomain))
+ if (!getInstIndexSet(dstAccess.opInst, &dstDomain))
return false;
// Return 'false' if loopDepth > numCommonLoops and if the ancestor operation
- // statement of 'srcAccess' does not properly dominate the ancestor operation
- // statement of 'dstAccess' in the same common statement block.
+ // instruction of 'srcAccess' does not properly dominate the ancestor
+ // operation instruction of 'dstAccess' in the same common instruction block.
unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
assert(loopDepth <= numCommonLoops + 1);
if (loopDepth > numCommonLoops &&
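checkMemrefAccessDependence ultimately asks whether "some source iteration and some destination iteration touch the same element, subject to the ordering implied by loopDepth" is feasible as an integer constraint system. Purely as intuition, the same question for one-dimensional accesses over a small constant-bound loop can be answered by brute force; a reference sketch under those toy assumptions (not the FlatAffineConstraints machinery):

#include <cstdint>
#include <functional>

// Toy single-subscript accesses A[f(i)] vs A[g(j)], with i and j
// ranging over the same constant-bound loop [lb, ub). Returns true
// iff the dependence constraint f(i) == g(j) is feasible.
bool hasDependence(int64_t lb, int64_t ub,
                   const std::function<int64_t(int64_t)> &f,
                   const std::function<int64_t(int64_t)> &g) {
  for (int64_t i = lb; i < ub; ++i)
    for (int64_t j = lb; j < ub; ++j)
      if (f(i) == g(j))
        return true;
  return false;
}

// Example: a store to A[i] and a load from A[i - 1] conflict:
//   hasDependence(0, 100, [](int64_t i) { return i; },
//                 [](int64_t j) { return j - 1; })  // true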
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index a45c5ffdf5e..d4b8a05dbf8 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -24,8 +24,8 @@
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/Statements.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h"
@@ -1248,22 +1248,22 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
numSymbols = newSymbolCount;
}
-bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
+bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) {
unsigned pos;
// Pre-condition for this method.
- if (!findId(forStmt, &pos)) {
+ if (!findId(forInst, &pos)) {
assert(0 && "Value not found");
return false;
}
- if (forStmt.getStep() != 1)
+ if (forInst.getStep() != 1)
LLVM_DEBUG(llvm::dbgs()
<< "Domain conservative: non-unit stride not handled\n");
// Adds a lower or upper bound when the bounds aren't constant.
auto addLowerOrUpperBound = [&](bool lower) -> bool {
- auto operands = lower ? forStmt.getLowerBoundOperands()
- : forStmt.getUpperBoundOperands();
+ auto operands = lower ? forInst.getLowerBoundOperands()
+ : forInst.getUpperBoundOperands();
for (const auto &operand : operands) {
unsigned loc;
if (!findId(*operand, &loc)) {
@@ -1271,8 +1271,8 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
addSymbolId(getNumSymbolIds(), const_cast<Value *>(operand));
loc = getNumDimIds() + getNumSymbolIds() - 1;
// Check if the symbol is a constant.
- if (auto *opStmt = operand->getDefiningInst()) {
- if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
+ if (auto *opInst = operand->getDefiningInst()) {
+ if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
setIdToConstant(*operand, constOp->getValue());
}
}
@@ -1292,7 +1292,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
}
auto boundMap =
- lower ? forStmt.getLowerBoundMap() : forStmt.getUpperBoundMap();
+ lower ? forInst.getLowerBoundMap() : forInst.getUpperBoundMap();
FlatAffineConstraints localVarCst;
std::vector<SmallVector<int64_t, 8>> flatExprs;
@@ -1322,16 +1322,16 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
return true;
};
- if (forStmt.hasConstantLowerBound()) {
- addConstantLowerBound(pos, forStmt.getConstantLowerBound());
+ if (forInst.hasConstantLowerBound()) {
+ addConstantLowerBound(pos, forInst.getConstantLowerBound());
} else {
// Non-constant lower bound case.
if (!addLowerOrUpperBound(/*lower=*/true))
return false;
}
- if (forStmt.hasConstantUpperBound()) {
- addConstantUpperBound(pos, forStmt.getConstantUpperBound() - 1);
+ if (forInst.hasConstantUpperBound()) {
+ addConstantUpperBound(pos, forInst.getConstantUpperBound() - 1);
return true;
}
// Non-constant upper bound case.
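One detail worth flagging in the hunk above: ForInst bounds are half-open (lb <= iv < ub), while FlatAffineConstraints stores inclusive constant bounds, which is why the code passes getConstantUpperBound() - 1. A tiny sketch of the conversion, using hypothetical names:

#include <cstdint>

struct InclusiveBounds { int64_t lower, upper; };

// Half-open loop range [lb, ub) as inclusive constraints:
//   iv >= lb  and  iv <= ub - 1,
// matching the `- 1` passed to addConstantUpperBound above.
InclusiveBounds toInclusive(int64_t lb, int64_t ub) {
  return {lb, ub - 1};
}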
diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp
index 0c8db07dbb4..4ee1b393068 100644
--- a/mlir/lib/Analysis/Dominance.cpp
+++ b/mlir/lib/Analysis/Dominance.cpp
@@ -21,7 +21,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Dominance.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
#include "llvm/Support/GenericDomTreeConstruction.h"
using namespace mlir;
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index dd14f38df55..b66b665c563 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -27,7 +27,7 @@
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/Functional.h"
@@ -42,27 +42,27 @@ using namespace mlir;
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
-AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) {
+AffineExpr mlir::getTripCountExpr(const ForInst &forInst) {
// upper_bound - lower_bound
int64_t loopSpan;
- int64_t step = forStmt.getStep();
- auto *context = forStmt.getContext();
+ int64_t step = forInst.getStep();
+ auto *context = forInst.getContext();
- if (forStmt.hasConstantBounds()) {
- int64_t lb = forStmt.getConstantLowerBound();
- int64_t ub = forStmt.getConstantUpperBound();
+ if (forInst.hasConstantBounds()) {
+ int64_t lb = forInst.getConstantLowerBound();
+ int64_t ub = forInst.getConstantUpperBound();
loopSpan = ub - lb;
} else {
- auto lbMap = forStmt.getLowerBoundMap();
- auto ubMap = forStmt.getUpperBoundMap();
+ auto lbMap = forInst.getLowerBoundMap();
+ auto ubMap = forInst.getUpperBoundMap();
// TODO(bondhugula): handle max/min of multiple expressions.
if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1)
return nullptr;
// TODO(bondhugula): handle bounds with different operands.
// Bounds have different operands, unhandled for now.
- if (!forStmt.matchingBoundOperandList())
+ if (!forInst.matchingBoundOperandList())
return nullptr;
// ub_expr - lb_expr
@@ -88,8 +88,8 @@ AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) {
/// Returns the trip count of the loop if it's a constant, None otherwise. This
/// method uses affine expression analysis (in turn using getTripCount) and is
/// able to determine constant trip count in non-trivial cases.
-llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
- auto tripCountExpr = getTripCountExpr(forStmt);
+llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForInst &forInst) {
+ auto tripCountExpr = getTripCountExpr(forInst);
if (!tripCountExpr)
return None;
@@ -103,8 +103,8 @@ llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through getTripCount), and
/// this method is thus able to determine non-trivial divisors.
-uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
- auto tripCountExpr = getTripCountExpr(forStmt);
+uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) {
+ auto tripCountExpr = getTripCountExpr(forInst);
if (!tripCountExpr)
return 1;
@@ -125,7 +125,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
}
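For the constant-bounds path above, the trip count reduces to ceildiv(ub - lb, step), and getLargestDivisorOfTripCount degenerates to the trip count itself whenever it is a known constant. A standalone sketch of that arithmetic (assumed positive step; not the AffineExpr-based general case):

#include <cstdint>

// Trip count of a half-open loop [lb, ub) with positive step:
// ceildiv(ub - lb, step), as in getTripCountExpr's simplification.
uint64_t tripCount(int64_t lb, int64_t ub, int64_t step) {
  if (ub <= lb)
    return 0;
  return static_cast<uint64_t>((ub - lb + step - 1) / step);
}

// When the trip count is a known constant, the largest divisor is
// the trip count itself; with no information the answer is 1.
uint64_t largestDivisorOfTripCount(int64_t lb, int64_t ub, int64_t step) {
  uint64_t tc = tripCount(lb, ub, step);
  return tc == 0 ? 1 : tc;
}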
bool mlir::isAccessInvariant(const Value &iv, const Value &index) {
- assert(isa<ForStmt>(iv) && "iv must be a ForStmt");
+ assert(isa<ForInst>(iv) && "iv must be a ForInst");
assert(index.getType().isa<IndexType>() && "index must be of IndexType");
SmallVector<OperationInst *, 4> affineApplyOps;
getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps);
@@ -172,7 +172,7 @@ mlir::getInvariantAccesses(const Value &iv,
}
/// Given:
-/// 1. an induction variable `iv` of type ForStmt;
+/// 1. an induction variable `iv` of type ForInst;
/// 2. a `memoryOp` of type const LoadOp& or const StoreOp&;
/// 3. the index of the `fastestVaryingDim` along which to check;
/// determines whether `memoryOp`[`fastestVaryingDim`] is a contiguous access
@@ -233,37 +233,37 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
return memRefType.getElementType().template isa<VectorType>();
}
-static bool isVectorTransferReadOrWrite(const Statement &stmt) {
- const auto *opStmt = cast<OperationInst>(&stmt);
- return opStmt->isa<VectorTransferReadOp>() ||
- opStmt->isa<VectorTransferWriteOp>();
+static bool isVectorTransferReadOrWrite(const Instruction &inst) {
+ const auto *opInst = cast<OperationInst>(&inst);
+ return opInst->isa<VectorTransferReadOp>() ||
+ opInst->isa<VectorTransferWriteOp>();
}
-using VectorizableStmtFun =
- std::function<bool(const ForStmt &, const OperationInst &)>;
+using VectorizableInstFun =
+ std::function<bool(const ForInst &, const OperationInst &)>;
-static bool isVectorizableLoopWithCond(const ForStmt &loop,
- VectorizableStmtFun isVectorizableStmt) {
+static bool isVectorizableLoopWithCond(const ForInst &loop,
+ VectorizableInstFun isVectorizableInst) {
if (!matcher::isParallelLoop(loop) && !matcher::isReductionLoop(loop)) {
return false;
}
// No vectorization across conditionals for now.
auto conditionals = matcher::If();
- auto *forStmt = const_cast<ForStmt *>(&loop);
- auto conditionalsMatched = conditionals.match(forStmt);
+ auto *forInst = const_cast<ForInst *>(&loop);
+ auto conditionalsMatched = conditionals.match(forInst);
if (!conditionalsMatched.empty()) {
return false;
}
auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite);
- auto vectorTransfersMatched = vectorTransfers.match(forStmt);
+ auto vectorTransfersMatched = vectorTransfers.match(forInst);
if (!vectorTransfersMatched.empty()) {
return false;
}
auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
- auto loadAndStoresMatched = loadAndStores.match(forStmt);
+ auto loadAndStoresMatched = loadAndStores.match(forInst);
for (auto ls : loadAndStoresMatched) {
auto *op = cast<OperationInst>(ls.first);
auto load = op->dyn_cast<LoadOp>();
@@ -275,7 +275,7 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop,
if (vector) {
return false;
}
- if (!isVectorizableStmt(loop, *op)) {
+ if (!isVectorizableInst(loop, *op)) {
return false;
}
}
@@ -283,9 +283,9 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop,
}
bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
- const ForStmt &loop, unsigned fastestVaryingDim) {
- VectorizableStmtFun fun(
- [fastestVaryingDim](const ForStmt &loop, const OperationInst &op) {
+ const ForInst &loop, unsigned fastestVaryingDim) {
+ VectorizableInstFun fun(
+ [fastestVaryingDim](const ForInst &loop, const OperationInst &op) {
auto load = op.dyn_cast<LoadOp>();
auto store = op.dyn_cast<StoreOp>();
return load ? isContiguousAccess(loop, *load, fastestVaryingDim)
@@ -294,37 +294,36 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
return isVectorizableLoopWithCond(loop, fun);
}
-bool mlir::isVectorizableLoop(const ForStmt &loop) {
- VectorizableStmtFun fun(
+bool mlir::isVectorizableLoop(const ForInst &loop) {
+ VectorizableInstFun fun(
// TODO: implement me
- [](const ForStmt &loop, const OperationInst &op) { return true; });
+ [](const ForInst &loop, const OperationInst &op) { return true; });
return isVectorizableLoopWithCond(loop, fun);
}
-/// Checks whether SSA dominance would be violated if a for stmt's body
-/// statements are shifted by the specified shifts. This method checks if a
+/// Checks whether SSA dominance would be violated if a for inst's body
+/// instructions are shifted by the specified shifts. This method checks if a
/// 'def' and all its uses have the same shift factor.
// TODO(mlir-team): extend this to check for memory-based dependence
// violation when we have the support.
-bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
+bool mlir::isInstwiseShiftValid(const ForInst &forInst,
ArrayRef<uint64_t> shifts) {
- auto *forBody = forStmt.getBody();
+ auto *forBody = forInst.getBody();
assert(shifts.size() == forBody->getInstructions().size());
unsigned s = 0;
- for (const auto &stmt : *forBody) {
- // A for or if stmt does not produce any def/results (that are used
+ for (const auto &inst : *forBody) {
+ // A for or if inst does not produce any def/results (that are used
// outside).
- if (const auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
- for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
- const Value *result = opStmt->getResult(i);
+ if (const auto *opInst = dyn_cast<OperationInst>(&inst)) {
+ for (unsigned i = 0, e = opInst->getNumResults(); i < e; ++i) {
+ const Value *result = opInst->getResult(i);
for (const InstOperand &use : result->getUses()) {
- // If an ancestor statement doesn't lie in the block of forStmt, there
- // is no shift to check.
- // This is a naive way. If performance becomes an issue, a map can
- // be used to store 'shifts' - to look up the shift for a statement in
- // constant time.
- if (auto *ancStmt = forBody->findAncestorInstInBlock(*use.getOwner()))
- if (shifts[s] != shifts[forBody->findInstPositionInBlock(*ancStmt)])
+ // If an ancestor instruction doesn't lie in the block of forInst,
+ // there is no shift to check. This is a naive way. If performance
+ // becomes an issue, a map can be used to store 'shifts' - to look up
+ // the shift for an instruction in constant time.
+ if (auto *ancInst = forBody->findAncestorInstInBlock(*use.getOwner()))
+ if (shifts[s] != shifts[forBody->findInstPositionInBlock(*ancInst)])
return false;
}
}
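The invariant isInstwiseShiftValid enforces is simple: a value defined in the loop body and every one of its in-body users must receive the same shift, otherwise shifting could move a use before its def. A toy sketch over a hypothetical flattened body, with indices standing in for instructions:

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical flattened loop body: instruction i has, per result,
// the indices of the body instructions that use that result.
struct BodyInst {
  std::vector<std::vector<std::size_t>> resultUsers;
};

// Mirrors isInstwiseShiftValid: a def and all of its in-body users
// must get the same shift, or a use could be moved before its def.
bool isShiftValid(const std::vector<BodyInst> &body,
                  const std::vector<uint64_t> &shifts) {
  for (std::size_t def = 0; def < body.size(); ++def)
    for (const auto &users : body[def].resultUsers)
      for (std::size_t use : users)
        if (shifts[use] != shifts[def])
          return false;
  return true;
}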
diff --git a/mlir/lib/Analysis/MLFunctionMatcher.cpp b/mlir/lib/Analysis/MLFunctionMatcher.cpp
index 12ce8481516..5bb4548e670 100644
--- a/mlir/lib/Analysis/MLFunctionMatcher.cpp
+++ b/mlir/lib/Analysis/MLFunctionMatcher.cpp
@@ -31,29 +31,29 @@ struct MLFunctionMatchesStorage {
/// Underlying storage for MLFunctionMatcher.
struct MLFunctionMatcherStorage {
- MLFunctionMatcherStorage(Statement::Kind k,
+ MLFunctionMatcherStorage(Instruction::Kind k,
MutableArrayRef<MLFunctionMatcher> c,
- FilterFunctionType filter, Statement *skip)
+ FilterFunctionType filter, Instruction *skip)
: kind(k), childrenMLFunctionMatchers(c.begin(), c.end()), filter(filter),
skip(skip) {}
- Statement::Kind kind;
+ Instruction::Kind kind;
SmallVector<MLFunctionMatcher, 4> childrenMLFunctionMatchers;
FilterFunctionType filter;
/// skip is needed so that we can implement match without switching on the
- /// type of the Statement.
+ /// type of the Instruction.
/// The idea is that a MLFunctionMatcher first checks if it matches locally
/// and then recursively applies its children matchers to its elem->children.
- /// Since we want to rely on the StmtWalker impl rather than duplicate its
+ /// Since we want to rely on the InstWalker impl rather than duplicate its
/// logic, we allow an off-by-one traversal to account for the fact that
/// we write:
///
- /// void match(Statement *elem) {
+ /// void match(Instruction *elem) {
/// for (auto &c : getChildrenMLFunctionMatchers()) {
/// MLFunctionMatcher childMLFunctionMatcher(...);
/// ^~~~ Needs off-by-one skip.
///
- Statement *skip;
+ Instruction *skip;
};
} // end namespace mlir
@@ -65,12 +65,12 @@ llvm::BumpPtrAllocator *&MLFunctionMatches::allocator() {
return allocator;
}
-void MLFunctionMatches::append(Statement *stmt, MLFunctionMatches children) {
+void MLFunctionMatches::append(Instruction *inst, MLFunctionMatches children) {
if (!storage) {
storage = allocator()->Allocate<MLFunctionMatchesStorage>();
- new (storage) MLFunctionMatchesStorage(std::make_pair(stmt, children));
+ new (storage) MLFunctionMatchesStorage(std::make_pair(inst, children));
} else {
- storage->matches.push_back(std::make_pair(stmt, children));
+ storage->matches.push_back(std::make_pair(inst, children));
}
}
MLFunctionMatches::iterator MLFunctionMatches::begin() {
@@ -98,10 +98,10 @@ MLFunctionMatches MLFunctionMatcher::match(Function *function) {
return matches;
}
-/// Calls walk on `statement`.
-MLFunctionMatches MLFunctionMatcher::match(Statement *statement) {
+/// Calls walk on `instruction`.
+MLFunctionMatches MLFunctionMatcher::match(Instruction *instruction) {
assert(!matches && "MLFunctionMatcher already matched!");
- this->walkPostOrder(statement);
+ this->walkPostOrder(instruction);
return matches;
}
@@ -117,17 +117,17 @@ unsigned MLFunctionMatcher::getDepth() {
return depth + 1;
}
-/// Matches a single statement in the following way:
-/// 1. checks the kind of statement against the matcher, if different then
+/// Matches a single instruction in the following way:
+/// 1. checks the kind of instruction against the matcher, if different then
/// there is no match;
-/// 2. calls the customizable filter function to refine the single statement
+/// 2. calls the customizable filter function to refine the single instruction
/// match with extra semantic constraints;
/// 3. if all is good, recursively matches the children patterns;
-/// 4. if all children match then the single statement matches too and is
+/// 4. if all children match then the single instruction matches too and is
/// appended to the list of matches;
/// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will
/// want to traverse in post-order DFS to avoid invalidating iterators.
-void MLFunctionMatcher::matchOne(Statement *elem) {
+void MLFunctionMatcher::matchOne(Instruction *elem) {
if (storage->skip == elem) {
return;
}
@@ -159,7 +159,8 @@ llvm::BumpPtrAllocator *&MLFunctionMatcher::allocator() {
return allocator;
}
-MLFunctionMatcher::MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child,
+MLFunctionMatcher::MLFunctionMatcher(Instruction::Kind k,
+ MLFunctionMatcher child,
FilterFunctionType filter)
: storage(allocator()->Allocate<MLFunctionMatcherStorage>()) {
// Initialize with placement new.
@@ -168,7 +169,7 @@ MLFunctionMatcher::MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child,
}
MLFunctionMatcher::MLFunctionMatcher(
- Statement::Kind k, MutableArrayRef<MLFunctionMatcher> children,
+ Instruction::Kind k, MutableArrayRef<MLFunctionMatcher> children,
FilterFunctionType filter)
: storage(allocator()->Allocate<MLFunctionMatcherStorage>()) {
// Initialize with placement new.
@@ -178,14 +179,14 @@ MLFunctionMatcher::MLFunctionMatcher(
MLFunctionMatcher
MLFunctionMatcher::forkMLFunctionMatcherAt(MLFunctionMatcher tmpl,
- Statement *stmt) {
+ Instruction *inst) {
MLFunctionMatcher res(tmpl.getKind(), tmpl.getChildrenMLFunctionMatchers(),
tmpl.getFilterFunction());
- res.storage->skip = stmt;
+ res.storage->skip = inst;
return res;
}
-Statement::Kind MLFunctionMatcher::getKind() { return storage->kind; }
+Instruction::Kind MLFunctionMatcher::getKind() { return storage->kind; }
MutableArrayRef<MLFunctionMatcher>
MLFunctionMatcher::getChildrenMLFunctionMatchers() {
@@ -200,54 +201,55 @@ namespace mlir {
namespace matcher {
MLFunctionMatcher Op(FilterFunctionType filter) {
- return MLFunctionMatcher(Statement::Kind::OperationInst, {}, filter);
+ return MLFunctionMatcher(Instruction::Kind::OperationInst, {}, filter);
}
MLFunctionMatcher If(MLFunctionMatcher child) {
- return MLFunctionMatcher(Statement::Kind::If, child, defaultFilterFunction);
+ return MLFunctionMatcher(Instruction::Kind::If, child, defaultFilterFunction);
}
MLFunctionMatcher If(FilterFunctionType filter, MLFunctionMatcher child) {
- return MLFunctionMatcher(Statement::Kind::If, child, filter);
+ return MLFunctionMatcher(Instruction::Kind::If, child, filter);
}
MLFunctionMatcher If(MutableArrayRef<MLFunctionMatcher> children) {
- return MLFunctionMatcher(Statement::Kind::If, children,
+ return MLFunctionMatcher(Instruction::Kind::If, children,
defaultFilterFunction);
}
MLFunctionMatcher If(FilterFunctionType filter,
MutableArrayRef<MLFunctionMatcher> children) {
- return MLFunctionMatcher(Statement::Kind::If, children, filter);
+ return MLFunctionMatcher(Instruction::Kind::If, children, filter);
}
MLFunctionMatcher For(MLFunctionMatcher child) {
- return MLFunctionMatcher(Statement::Kind::For, child, defaultFilterFunction);
+ return MLFunctionMatcher(Instruction::Kind::For, child,
+ defaultFilterFunction);
}
MLFunctionMatcher For(FilterFunctionType filter, MLFunctionMatcher child) {
- return MLFunctionMatcher(Statement::Kind::For, child, filter);
+ return MLFunctionMatcher(Instruction::Kind::For, child, filter);
}
MLFunctionMatcher For(MutableArrayRef<MLFunctionMatcher> children) {
- return MLFunctionMatcher(Statement::Kind::For, children,
+ return MLFunctionMatcher(Instruction::Kind::For, children,
defaultFilterFunction);
}
MLFunctionMatcher For(FilterFunctionType filter,
MutableArrayRef<MLFunctionMatcher> children) {
- return MLFunctionMatcher(Statement::Kind::For, children, filter);
+ return MLFunctionMatcher(Instruction::Kind::For, children, filter);
}
// TODO(ntv): parallel annotation on loops.
-bool isParallelLoop(const Statement &stmt) {
- const auto *loop = cast<ForStmt>(&stmt);
+bool isParallelLoop(const Instruction &inst) {
+ const auto *loop = cast<ForInst>(&inst);
return (void *)loop || true; // loop->isParallel();
};
// TODO(ntv): reduction annotation on loops.
-bool isReductionLoop(const Statement &stmt) {
- const auto *loop = cast<ForStmt>(&stmt);
+bool isReductionLoop(const Instruction &inst) {
+ const auto *loop = cast<ForInst>(&inst);
return (void *)loop || true; // loop->isReduction();
};
-bool isLoadOrStore(const Statement &stmt) {
- const auto *opStmt = dyn_cast<OperationInst>(&stmt);
- return opStmt && (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>());
+bool isLoadOrStore(const Instruction &inst) {
+ const auto *opInst = dyn_cast<OperationInst>(&inst);
+ return opInst && (opInst->isa<LoadOp>() || opInst->isa<StoreOp>());
};
} // end namespace matcher
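The combinators above nest structurally, so a matcher for "loads and stores under a doubly-nested loop" is just For(For(Op(isLoadOrStore))). A hypothetical usage sketch, assuming (as MLFunctionMatchesStorage suggests) that iterating MLFunctionMatches yields (Instruction*, MLFunctionMatches) pairs:

// Hypothetical usage of the combinators above; illustration only.
using namespace mlir;
using namespace mlir::matcher;

void findNestedMemOps(Function *f) {
  auto pattern = For(For(Op(isLoadOrStore)));
  auto matches = pattern.match(f);
  for (auto m : matches) {
    // m.first: the matched outer ForInst; m.second: nested matches.
  }
}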
diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp
index ad935faf05d..e8b668892b8 100644
--- a/mlir/lib/Analysis/MemRefBoundCheck.cpp
+++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp
@@ -26,7 +26,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/Support/Debug.h"
@@ -38,14 +38,14 @@ using namespace mlir;
namespace {
/// Checks for out-of-bound memref access subscripts.
-struct MemRefBoundCheck : public FunctionPass, StmtWalker<MemRefBoundCheck> {
+struct MemRefBoundCheck : public FunctionPass, InstWalker<MemRefBoundCheck> {
explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {}
PassResult runOnMLFunction(Function *f) override;
// Not applicable to CFG functions.
PassResult runOnCFGFunction(Function *f) override { return success(); }
- void visitOperationInst(OperationInst *opStmt);
+ void visitOperationInst(OperationInst *opInst);
static char passID;
};
@@ -58,10 +58,10 @@ FunctionPass *mlir::createMemRefBoundCheckPass() {
return new MemRefBoundCheck();
}
-void MemRefBoundCheck::visitOperationInst(OperationInst *opStmt) {
- if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
+void MemRefBoundCheck::visitOperationInst(OperationInst *opInst) {
+ if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
boundCheckLoadOrStoreOp(loadOp);
- } else if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
+ } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
boundCheckLoadOrStoreOp(storeOp);
}
// TODO(bondhugula): do this for DMA ops as well.
diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp
index bb668f78624..8391f15b6d3 100644
--- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp
+++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp
@@ -25,7 +25,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/Support/Debug.h"
@@ -39,7 +39,7 @@ namespace {
// TODO(andydavis) Add common surrounding loop depth-wise dependence checks.
/// Checks dependences between all pairs of memref accesses in a Function.
struct MemRefDependenceCheck : public FunctionPass,
- StmtWalker<MemRefDependenceCheck> {
+ InstWalker<MemRefDependenceCheck> {
SmallVector<OperationInst *, 4> loadsAndStores;
explicit MemRefDependenceCheck()
: FunctionPass(&MemRefDependenceCheck::passID) {}
@@ -48,9 +48,9 @@ struct MemRefDependenceCheck : public FunctionPass,
// Not applicable to CFG functions.
PassResult runOnCFGFunction(Function *f) override { return success(); }
- void visitOperationInst(OperationInst *opStmt) {
- if (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>()) {
- loadsAndStores.push_back(opStmt);
+ void visitOperationInst(OperationInst *opInst) {
+ if (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()) {
+ loadsAndStores.push_back(opInst);
}
}
static char passID;
@@ -74,17 +74,17 @@ static void addMemRefAccessIndices(
}
}
-// Populates 'access' with memref, indices and opstmt from 'loadOrStoreOpStmt'.
-static void getMemRefAccess(const OperationInst *loadOrStoreOpStmt,
+// Populates 'access' with memref, indices and opinst from 'loadOrStoreOpInst'.
+static void getMemRefAccess(const OperationInst *loadOrStoreOpInst,
MemRefAccess *access) {
- access->opStmt = loadOrStoreOpStmt;
- if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
+ access->opInst = loadOrStoreOpInst;
+ if (auto loadOp = loadOrStoreOpInst->dyn_cast<LoadOp>()) {
access->memref = loadOp->getMemRef();
addMemRefAccessIndices(loadOp->getIndices(), loadOp->getMemRefType(),
access);
} else {
- assert(loadOrStoreOpStmt->isa<StoreOp>());
- auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
+ assert(loadOrStoreOpInst->isa<StoreOp>());
+ auto storeOp = loadOrStoreOpInst->dyn_cast<StoreOp>();
access->memref = storeOp->getMemRef();
addMemRefAccessIndices(storeOp->getIndices(), storeOp->getMemRefType(),
access);
@@ -93,8 +93,8 @@ static void getMemRefAccess(const OperationInst *loadOrStoreOpStmt,
// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
// where each lists loops from outer-most to inner-most in loop nest.
-static unsigned getNumCommonSurroundingLoops(ArrayRef<const ForStmt *> loopsA,
- ArrayRef<const ForStmt *> loopsB) {
+static unsigned getNumCommonSurroundingLoops(ArrayRef<const ForInst *> loopsA,
+ ArrayRef<const ForInst *> loopsB) {
unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
unsigned numCommonLoops = 0;
for (unsigned i = 0; i < minNumLoops; ++i) {
@@ -133,18 +133,18 @@ getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth,
// the source access.
static void checkDependences(ArrayRef<OperationInst *> loadsAndStores) {
for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) {
- auto *srcOpStmt = loadsAndStores[i];
+ auto *srcOpInst = loadsAndStores[i];
MemRefAccess srcAccess;
- getMemRefAccess(srcOpStmt, &srcAccess);
- SmallVector<ForStmt *, 4> srcLoops;
- getLoopIVs(*srcOpStmt, &srcLoops);
+ getMemRefAccess(srcOpInst, &srcAccess);
+ SmallVector<ForInst *, 4> srcLoops;
+ getLoopIVs(*srcOpInst, &srcLoops);
for (unsigned j = 0; j < e; ++j) {
- auto *dstOpStmt = loadsAndStores[j];
+ auto *dstOpInst = loadsAndStores[j];
MemRefAccess dstAccess;
- getMemRefAccess(dstOpStmt, &dstAccess);
+ getMemRefAccess(dstOpInst, &dstAccess);
- SmallVector<ForStmt *, 4> dstLoops;
- getLoopIVs(*dstOpStmt, &dstLoops);
+ SmallVector<ForInst *, 4> dstLoops;
+ getLoopIVs(*dstOpInst, &dstLoops);
unsigned numCommonLoops =
getNumCommonSurroundingLoops(srcLoops, dstLoops);
for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
@@ -156,7 +156,7 @@ static void checkDependences(ArrayRef<OperationInst *> loadsAndStores) {
// TODO(andydavis) Print dependence type (i.e. RAW, etc) and print
// distance vectors as: ([2, 3], [0, 10]). Also, shorten distance
// vectors from ([1, 1], [3, 3]) to (1, 3).
- srcOpStmt->emitNote(
+ srcOpInst->emitNote(
"dependence from " + Twine(i) + " to " + Twine(j) + " at depth " +
Twine(d) + " = " +
getDirectionVectorStr(ret, numCommonLoops, d, dependenceComponents)
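checkDependences is deliberately quadratic: every ordered (src, dst) pair of memory operations is tested at every depth from 1 to numCommonLoops + 1, the extra depth covering loop-independent dependences. A skeleton of that driver with hypothetical oracle callbacks (the real numCommonLoops is computed per pair; it is a parameter here only to keep the sketch small):

#include <cstddef>
#include <functional>
#include <vector>

struct Access { int id; };  // hypothetical load/store record

// Skeleton of checkDependences: all ordered pairs, all depths.
void checkAllPairs(
    const std::vector<Access> &ops, unsigned numCommonLoops,
    const std::function<bool(const Access &, const Access &,
                             unsigned)> &dependsAtDepth,
    const std::function<void(std::size_t, std::size_t, unsigned)> &report) {
  for (std::size_t i = 0, e = ops.size(); i < e; ++i)
    for (std::size_t j = 0; j < e; ++j)
      // d == numCommonLoops + 1 is the loop-independent case.
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d)
        if (dependsAtDepth(ops[i], ops[j], d))
          report(i, j, d);
}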
diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp
index f4c509a5132..07edb13d1a3 100644
--- a/mlir/lib/Analysis/OpStats.cpp
+++ b/mlir/lib/Analysis/OpStats.cpp
@@ -16,9 +16,9 @@
// =============================================================================
#include "mlir/IR/Function.h"
+#include "mlir/IR/InstVisitor.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/Statements.h"
-#include "mlir/IR/StmtVisitor.h"
#include "mlir/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/raw_ostream.h"
@@ -26,7 +26,7 @@
using namespace mlir;
namespace {
-struct PrintOpStatsPass : public FunctionPass, StmtWalker<PrintOpStatsPass> {
+struct PrintOpStatsPass : public FunctionPass, InstWalker<PrintOpStatsPass> {
explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs())
: FunctionPass(&PrintOpStatsPass::passID), os(os) {}
@@ -38,7 +38,7 @@ struct PrintOpStatsPass : public FunctionPass, StmtWalker<PrintOpStatsPass> {
// Process ML functions and the operation statements within them.
PassResult runOnMLFunction(Function *function) override;
- void visitOperationInst(OperationInst *stmt);
+ void visitOperationInst(OperationInst *inst);
// Print summary of op stats.
void printSummary();
@@ -69,8 +69,8 @@ PassResult PrintOpStatsPass::runOnCFGFunction(Function *function) {
return success();
}
-void PrintOpStatsPass::visitOperationInst(OperationInst *stmt) {
- ++opCount[stmt->getName().getStringRef()];
+void PrintOpStatsPass::visitOperationInst(OperationInst *inst) {
+ ++opCount[inst->getName().getStringRef()];
}
PassResult PrintOpStatsPass::runOnMLFunction(Function *function) {
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 393d7c59de0..a8cec771f0d 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -22,7 +22,7 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/STLExtras.h"
@@ -38,36 +38,36 @@ using namespace mlir;
using llvm::DenseSet;
using llvm::SetVector;
-void mlir::getForwardSlice(Statement *stmt,
- SetVector<Statement *> *forwardSlice,
+void mlir::getForwardSlice(Instruction *inst,
+ SetVector<Instruction *> *forwardSlice,
TransitiveFilter filter, bool topLevel) {
- if (!stmt) {
+ if (!inst) {
return;
}
// Evaluate whether we should keep this use.
// This is useful in particular to implement scoping; i.e. return the
// transitive forwardSlice in the current scope.
- if (!filter(stmt)) {
+ if (!filter(inst)) {
return;
}
- if (auto *opStmt = dyn_cast<OperationInst>(stmt)) {
- assert(opStmt->getNumResults() <= 1 && "NYI: multiple results");
- if (opStmt->getNumResults() > 0) {
- for (auto &u : opStmt->getResult(0)->getUses()) {
- auto *ownerStmt = u.getOwner();
- if (forwardSlice->count(ownerStmt) == 0) {
- getForwardSlice(ownerStmt, forwardSlice, filter,
+ if (auto *opInst = dyn_cast<OperationInst>(inst)) {
+ assert(opInst->getNumResults() <= 1 && "NYI: multiple results");
+ if (opInst->getNumResults() > 0) {
+ for (auto &u : opInst->getResult(0)->getUses()) {
+ auto *ownerInst = u.getOwner();
+ if (forwardSlice->count(ownerInst) == 0) {
+ getForwardSlice(ownerInst, forwardSlice, filter,
/*topLevel=*/false);
}
}
}
- } else if (auto *forStmt = dyn_cast<ForStmt>(stmt)) {
- for (auto &u : forStmt->getUses()) {
- auto *ownerStmt = u.getOwner();
- if (forwardSlice->count(ownerStmt) == 0) {
- getForwardSlice(ownerStmt, forwardSlice, filter,
+ } else if (auto *forInst = dyn_cast<ForInst>(inst)) {
+ for (auto &u : forInst->getUses()) {
+ auto *ownerInst = u.getOwner();
+ if (forwardSlice->count(ownerInst) == 0) {
+ getForwardSlice(ownerInst, forwardSlice, filter,
/*topLevel=*/false);
}
}
@@ -80,61 +80,61 @@ void mlir::getForwardSlice(Statement *stmt,
// std::reverse does not work out of the box on SetVector and I want an
// in-place swap based thing (the real std::reverse, not the LLVM adapter).
// TODO(clattner): Consider adding an extra method?
- std::vector<Statement *> v(forwardSlice->takeVector());
+ std::vector<Instruction *> v(forwardSlice->takeVector());
forwardSlice->insert(v.rbegin(), v.rend());
} else {
- forwardSlice->insert(stmt);
+ forwardSlice->insert(inst);
}
}
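getForwardSlice is a use-chain DFS: it recurses into all transitive users, records each instruction after its dependents, skips the root itself, and reverses once at the top level so definitions precede users. A simplified standalone sketch over a hypothetical def-use node:

#include <algorithm>
#include <set>
#include <vector>

struct Node { std::vector<Node *> users; };  // hypothetical def-use node

namespace {
void visit(Node *n, std::set<Node *> &seen, std::vector<Node *> &out) {
  if (!n || seen.count(n))
    return;
  seen.insert(n);
  for (Node *user : n->users)
    visit(user, seen, out);
  out.push_back(n);  // each node lands after all of its users
}
} // namespace

// Mirrors getForwardSlice: the root is excluded, and the top-level
// call reverses so definitions precede their users.
std::vector<Node *> forwardSlice(Node *root) {
  std::set<Node *> seen;
  std::vector<Node *> out;
  if (root)
    for (Node *user : root->users)
      visit(user, seen, out);
  std::reverse(out.begin(), out.end());
  return out;
}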
-void mlir::getBackwardSlice(Statement *stmt,
- SetVector<Statement *> *backwardSlice,
+void mlir::getBackwardSlice(Instruction *inst,
+ SetVector<Instruction *> *backwardSlice,
TransitiveFilter filter, bool topLevel) {
- if (!stmt) {
+ if (!inst) {
return;
}
// Evaluate whether we should keep this def.
// This is useful in particular to implement scoping; i.e. return the
// transitive forwardSlice in the current scope.
- if (!filter(stmt)) {
+ if (!filter(inst)) {
return;
}
- for (auto *operand : stmt->getOperands()) {
- auto *stmt = operand->getDefiningInst();
- if (backwardSlice->count(stmt) == 0) {
- getBackwardSlice(stmt, backwardSlice, filter,
+ for (auto *operand : inst->getOperands()) {
+ auto *inst = operand->getDefiningInst();
+ if (backwardSlice->count(inst) == 0) {
+ getBackwardSlice(inst, backwardSlice, filter,
/*topLevel=*/false);
}
}
- // Don't insert the top level statement, we just queried on it and don't
+ // Don't insert the top level instruction, we just queried on it and don't
// want it in the results.
if (!topLevel) {
- backwardSlice->insert(stmt);
+ backwardSlice->insert(inst);
}
}
-SetVector<Statement *> mlir::getSlice(Statement *stmt,
- TransitiveFilter backwardFilter,
- TransitiveFilter forwardFilter) {
- SetVector<Statement *> slice;
- slice.insert(stmt);
+SetVector<Instruction *> mlir::getSlice(Instruction *inst,
+ TransitiveFilter backwardFilter,
+ TransitiveFilter forwardFilter) {
+ SetVector<Instruction *> slice;
+ slice.insert(inst);
unsigned currentIndex = 0;
- SetVector<Statement *> backwardSlice;
- SetVector<Statement *> forwardSlice;
+ SetVector<Instruction *> backwardSlice;
+ SetVector<Instruction *> forwardSlice;
while (currentIndex != slice.size()) {
- auto *currentStmt = (slice)[currentIndex];
- // Compute and insert the backwardSlice starting from currentStmt.
+ auto *currentInst = (slice)[currentIndex];
+ // Compute and insert the backwardSlice starting from currentInst.
backwardSlice.clear();
- getBackwardSlice(currentStmt, &backwardSlice, backwardFilter);
+ getBackwardSlice(currentInst, &backwardSlice, backwardFilter);
slice.insert(backwardSlice.begin(), backwardSlice.end());
- // Compute and insert the forwardSlice starting from currentStmt.
+ // Compute and insert the forwardSlice starting from currentInst.
forwardSlice.clear();
- getForwardSlice(currentStmt, &forwardSlice, forwardFilter);
+ getForwardSlice(currentInst, &forwardSlice, forwardFilter);
slice.insert(forwardSlice.begin(), forwardSlice.end());
++currentIndex;
}
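getSlice is a fixed-point computation: starting from the root, it repeatedly unions in the backward and forward slices of every element already reached, until the set stops growing. A generic sketch of that loop, with SetVector emulated by an ordered, deduplicated vector and the slice primitives passed in as callbacks:

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// Fixed-point slice: union backward and forward slices of every
// element reached, until no new elements appear (as in getSlice).
template <typename Node>
std::vector<Node *>
getSliceSketch(Node *root,
               const std::function<std::vector<Node *>(Node *)> &backward,
               const std::function<std::vector<Node *>(Node *)> &forward) {
  std::vector<Node *> slice{root};
  auto insertUnique = [&slice](const std::vector<Node *> &v) {
    for (Node *n : v)
      if (std::find(slice.begin(), slice.end(), n) == slice.end())
        slice.push_back(n);  // SetVector-style: ordered, deduplicated
  };
  for (std::size_t i = 0; i < slice.size(); ++i) {  // grows as we go
    insertUnique(backward(slice[i]));
    insertUnique(forward(slice[i]));
  }
  return slice;
}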
@@ -144,24 +144,24 @@ SetVector<Statement *> mlir::getSlice(Statement *stmt,
namespace {
/// DFS post-order implementation that maintains a global count to work across
/// multiple invocations, to help implement topological sort on multi-root DAGs.
-/// We traverse all statements but only record the ones that appear in `toSort`
-/// for the final result.
+/// We traverse all instructions but only record the ones that appear in
+/// `toSort` for the final result.
struct DFSState {
- DFSState(const SetVector<Statement *> &set)
+ DFSState(const SetVector<Instruction *> &set)
: toSort(set), topologicalCounts(), seen() {}
- const SetVector<Statement *> &toSort;
- SmallVector<Statement *, 16> topologicalCounts;
- DenseSet<Statement *> seen;
+ const SetVector<Instruction *> &toSort;
+ SmallVector<Instruction *, 16> topologicalCounts;
+ DenseSet<Instruction *> seen;
};
} // namespace
-static void DFSPostorder(Statement *current, DFSState *state) {
- auto *opStmt = cast<OperationInst>(current);
- assert(opStmt->getNumResults() <= 1 && "NYI: multi-result");
- if (opStmt->getNumResults() > 0) {
- for (auto &u : opStmt->getResult(0)->getUses()) {
- auto *stmt = u.getOwner();
- DFSPostorder(stmt, state);
+static void DFSPostorder(Instruction *current, DFSState *state) {
+ auto *opInst = cast<OperationInst>(current);
+ assert(opInst->getNumResults() <= 1 && "NYI: multi-result");
+ if (opInst->getNumResults() > 0) {
+ for (auto &u : opInst->getResult(0)->getUses()) {
+ auto *inst = u.getOwner();
+ DFSPostorder(inst, state);
}
}
bool inserted;
@@ -175,8 +175,8 @@ static void DFSPostorder(Statement *current, DFSState *state) {
}
}
-SetVector<Statement *>
-mlir::topologicalSort(const SetVector<Statement *> &toSort) {
+SetVector<Instruction *>
+mlir::topologicalSort(const SetVector<Instruction *> &toSort) {
if (toSort.empty()) {
return toSort;
}
@@ -189,7 +189,7 @@ mlir::topologicalSort(const SetVector<Statement *> &toSort) {
}
// Reorder and return.
- SetVector<Statement *> res;
+ SetVector<Instruction *> res;
for (auto it = state.topologicalCounts.rbegin(),
eit = state.topologicalCounts.rend();
it != eit; ++it) {
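topologicalSort runs a DFS post-order from each root and reverses the result, so every definition ends up before all of its users; the real implementation additionally filters the output to the instructions in toSort. A simplified sketch without that filtering:

#include <set>
#include <vector>

struct Node { std::vector<Node *> users; };  // hypothetical def-use node

namespace {
void dfsPostorder(Node *n, std::set<Node *> &seen,
                  std::vector<Node *> &order) {
  if (!seen.insert(n).second)
    return;  // already visited
  for (Node *user : n->users)
    dfsPostorder(user, seen, order);
  order.push_back(n);  // users recorded before their definitions
}
} // namespace

// DFS post-order from each root, then reverse: every definition
// ends up before all of its users.
std::vector<Node *> topoSort(const std::vector<Node *> &roots) {
  std::set<Node *> seen;
  std::vector<Node *> order;
  for (Node *root : roots)
    dfsPostorder(root, seen, order);
  return {order.rbegin(), order.rend()};
}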
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index f6191418f54..a7fc5ac619e 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -34,8 +34,8 @@
using namespace mlir;
-/// Returns true if statement 'a' properly dominates statement b.
-bool mlir::properlyDominates(const Statement &a, const Statement &b) {
+/// Returns true if instruction 'a' properly dominates instruction b.
+bool mlir::properlyDominates(const Instruction &a, const Instruction &b) {
if (&a == &b)
return false;
@@ -64,24 +64,24 @@ bool mlir::properlyDominates(const Statement &a, const Statement &b) {
return false;
}
-/// Returns true if statement A dominates statement B.
-bool mlir::dominates(const Statement &a, const Statement &b) {
+/// Returns true if instruction A dominates instruction B.
+bool mlir::dominates(const Instruction &a, const Instruction &b) {
return &a == &b || properlyDominates(a, b);
}
-/// Populates 'loops' with IVs of the loops surrounding 'stmt' ordered from
-/// the outermost 'for' statement to the innermost one.
-void mlir::getLoopIVs(const Statement &stmt,
- SmallVectorImpl<ForStmt *> *loops) {
- auto *currStmt = stmt.getParentStmt();
- ForStmt *currForStmt;
- // Traverse up the hierarchy collecing all 'for' statement while skipping over
- // 'if' statements.
- while (currStmt && ((currForStmt = dyn_cast<ForStmt>(currStmt)) ||
- isa<IfStmt>(currStmt))) {
- if (currForStmt)
- loops->push_back(currForStmt);
- currStmt = currStmt->getParentStmt();
+/// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from
+/// the outermost 'for' instruction to the innermost one.
+void mlir::getLoopIVs(const Instruction &inst,
+ SmallVectorImpl<ForInst *> *loops) {
+ auto *currInst = inst.getParentInst();
+ ForInst *currForInst;
+ // Traverse up the hierarchy collecting all 'for' instructions while skipping
+ // over 'if' instructions.
+ while (currInst && ((currForInst = dyn_cast<ForInst>(currInst)) ||
+ isa<IfInst>(currInst))) {
+ if (currForInst)
+ loops->push_back(currForInst);
+ currInst = currInst->getParentInst();
}
std::reverse(loops->begin(), loops->end());
}
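getLoopIVs walks parent links upward, collecting 'for' instructions and stepping over 'if' instructions, then reverses so the outermost loop comes first. A standalone sketch with a hypothetical parent-linked node:

#include <algorithm>
#include <vector>

// Hypothetical parent-linked instruction, standing in for Instruction.
struct Inst {
  Inst *parent = nullptr;
  bool isFor = false;  // stand-in for isa<ForInst>
  bool isIf = false;   // stand-in for isa<IfInst>
};

// Mirrors getLoopIVs: walk up through enclosing for/if instructions,
// collect the 'for's, and reverse to outermost-first order.
std::vector<Inst *> getLoopIVs(const Inst &inst) {
  std::vector<Inst *> loops;
  Inst *cur = inst.parent;
  while (cur && (cur->isFor || cur->isIf)) {
    if (cur->isFor)
      loops.push_back(cur);
    cur = cur->parent;
  }
  std::reverse(loops.begin(), loops.end());
  return loops;
}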
@@ -129,7 +129,7 @@ Optional<int64_t> MemRefRegion::getBoundingConstantSizeAndShape(
/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parameteric in 'loopDepth' loops
-/// surrounding opStmt and any additional Function symbols. Returns false if
+/// surrounding opInst and any additional Function symbols. Returns false if
/// this fails due to yet unimplemented cases.
// For example, the memref region for this load operation at loopDepth = 1 will
// be as below:
@@ -145,21 +145,21 @@ Optional<int64_t> MemRefRegion::getBoundingConstantSizeAndShape(
//
// TODO(bondhugula): extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
-bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
+bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
MemRefRegion *region) {
OpPointer<LoadOp> loadOp;
OpPointer<StoreOp> storeOp;
unsigned rank;
SmallVector<Value *, 4> indices;
- if ((loadOp = opStmt->dyn_cast<LoadOp>())) {
+ if ((loadOp = opInst->dyn_cast<LoadOp>())) {
rank = loadOp->getMemRefType().getRank();
for (auto *index : loadOp->getIndices()) {
indices.push_back(index);
}
region->memref = loadOp->getMemRef();
region->setWrite(false);
- } else if ((storeOp = opStmt->dyn_cast<StoreOp>())) {
+ } else if ((storeOp = opInst->dyn_cast<StoreOp>())) {
rank = storeOp->getMemRefType().getRank();
for (auto *index : storeOp->getIndices()) {
indices.push_back(index);
@@ -173,7 +173,7 @@ bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
// Build the constraints for this region.
FlatAffineConstraints *regionCst = region->getConstraints();
- FuncBuilder b(opStmt);
+ FuncBuilder b(opInst);
auto idMap = b.getMultiDimIdentityMap(rank);
// Initialize 'accessValueMap' and compose with reachable AffineApplyOps.
@@ -192,20 +192,20 @@ bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
unsigned numSymbols = accessMap.getNumSymbols();
// Add inequalties for loop lower/upper bounds.
for (unsigned i = 0; i < numDims + numSymbols; ++i) {
- if (auto *loop = dyn_cast<ForStmt>(accessValueMap.getOperand(i))) {
+ if (auto *loop = dyn_cast<ForInst>(accessValueMap.getOperand(i))) {
// Note that regionCst can now have more dimensions than accessMap if the
// bounds expressions involve outer loops or other symbols.
- // TODO(bondhugula): rewrite this to use getStmtIndexSet; this way
+ // TODO(bondhugula): rewrite this to use getInstIndexSet; this way
// conditionals will be handled when the latter supports it.
- if (!regionCst->addForStmtDomain(*loop))
+ if (!regionCst->addForInstDomain(*loop))
return false;
} else {
// Has to be a valid symbol.
auto *symbol = accessValueMap.getOperand(i);
assert(symbol->isValidSymbol());
// Check if the symbol is a constant.
- if (auto *opStmt = symbol->getDefiningInst()) {
- if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
+ if (auto *opInst = symbol->getDefiningInst()) {
+ if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
regionCst->setIdToConstant(*symbol, constOp->getValue());
}
}
@@ -220,12 +220,12 @@ bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
// Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
// this memref region is symbolic.
- SmallVector<ForStmt *, 4> outerIVs;
- getLoopIVs(*opStmt, &outerIVs);
+ SmallVector<ForInst *, 4> outerIVs;
+ getLoopIVs(*opInst, &outerIVs);
outerIVs.resize(loopDepth);
for (auto *operand : accessValueMap.getOperands()) {
- ForStmt *iv;
- if ((iv = dyn_cast<ForStmt>(operand)) &&
+ ForInst *iv;
+ if ((iv = dyn_cast<ForInst>(operand)) &&
std::find(outerIVs.begin(), outerIVs.end(), iv) == outerIVs.end()) {
regionCst->projectOut(operand);
}
@@ -282,9 +282,9 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
std::is_same<LoadOrStoreOpPointer, OpPointer<StoreOp>>::value,
"function argument should be either a LoadOp or a StoreOp");
- OperationInst *opStmt = loadOrStoreOp->getInstruction();
+ OperationInst *opInst = loadOrStoreOp->getInstruction();
MemRefRegion region;
- if (!getMemRefRegion(opStmt, /*loopDepth=*/0, &region))
+ if (!getMemRefRegion(opInst, /*loopDepth=*/0, &region))
return false;
LLVM_DEBUG(llvm::dbgs() << "Memory region");
LLVM_DEBUG(region.getConstraints()->dump());
@@ -333,43 +333,43 @@ template bool mlir::boundCheckLoadOrStoreOp(OpPointer<LoadOp> loadOp,
template bool mlir::boundCheckLoadOrStoreOp(OpPointer<StoreOp> storeOp,
bool emitError);
-// Returns in 'positions' the Block positions of 'stmt' in each ancestor
-// Block from the Block containing statement, stopping at 'limitBlock'.
-static void findStmtPosition(const Statement *stmt, Block *limitBlock,
+// Returns in 'positions' the Block positions of 'inst' in each ancestor
+// Block from the Block containing instruction, stopping at 'limitBlock'.
+static void findInstPosition(const Instruction *inst, Block *limitBlock,
SmallVectorImpl<unsigned> *positions) {
- Block *block = stmt->getBlock();
+ Block *block = inst->getBlock();
while (block != limitBlock) {
- int stmtPosInBlock = block->findInstPositionInBlock(*stmt);
- assert(stmtPosInBlock >= 0);
- positions->push_back(stmtPosInBlock);
- stmt = block->getContainingInst();
- block = stmt->getBlock();
+ int instPosInBlock = block->findInstPositionInBlock(*inst);
+ assert(instPosInBlock >= 0);
+ positions->push_back(instPosInBlock);
+ inst = block->getContainingInst();
+ block = inst->getBlock();
}
std::reverse(positions->begin(), positions->end());
}
-// Returns the Statement in a possibly nested set of Blocks, where the
-// position of the statement is represented by 'positions', which has a
+// Returns the Instruction in a possibly nested set of Blocks, where the
+// position of the instruction is represented by 'positions', which has a
// Block position for each level of nesting.
-static Statement *getStmtAtPosition(ArrayRef<unsigned> positions,
- unsigned level, Block *block) {
+static Instruction *getInstAtPosition(ArrayRef<unsigned> positions,
+ unsigned level, Block *block) {
unsigned i = 0;
- for (auto &stmt : *block) {
+ for (auto &inst : *block) {
if (i != positions[level]) {
++i;
continue;
}
if (level == positions.size() - 1)
- return &stmt;
- if (auto *childForStmt = dyn_cast<ForStmt>(&stmt))
- return getStmtAtPosition(positions, level + 1, childForStmt->getBody());
+ return &inst;
+ if (auto *childForInst = dyn_cast<ForInst>(&inst))
+ return getInstAtPosition(positions, level + 1, childForInst->getBody());
- if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
- auto *ret = getStmtAtPosition(positions, level + 1, ifStmt->getThen());
+ if (auto *ifInst = dyn_cast<IfInst>(&inst)) {
+ auto *ret = getInstAtPosition(positions, level + 1, ifInst->getThen());
if (ret != nullptr)
return ret;
- if (auto *elseClause = ifStmt->getElse())
- return getStmtAtPosition(positions, level + 1, elseClause);
+ if (auto *elseClause = ifInst->getElse())
+ return getInstAtPosition(positions, level + 1, elseClause);
}
}
return nullptr;
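
A hedged sketch of how these two helpers compose; 'srcInst', 'limitBlock', and 'clonedBody' are names assumed for this illustration:

    // 'positions' records one block index per nesting level, outermost
    // first, e.g. {2, 0}: third instruction of the limit block, then the
    // first instruction of that ForInst's body.
    SmallVector<unsigned, 4> positions;
    findInstPosition(srcInst, limitBlock, &positions);
    // Walking the same indices through a structurally identical (e.g.
    // cloned) region recovers the corresponding instruction.
    Instruction *twin = getInstAtPosition(positions, /*level=*/0, clonedBody);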
@@ -379,7 +379,7 @@ static Statement *getStmtAtPosition(ArrayRef<unsigned> positions,
// dependence constraint system to create AffineMaps with which to adjust the
// loop bounds of the inserted computation slice so that they are functions of the
// loop IVs and symbols of the loops surrounding 'dstAccess'.
-ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
+ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
MemRefAccess *dstAccess,
unsigned srcLoopDepth,
unsigned dstLoopDepth) {
@@ -390,14 +390,14 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
return nullptr;
}
// Get loop nest surrounding src operation.
- SmallVector<ForStmt *, 4> srcLoopNest;
- getLoopIVs(*srcAccess->opStmt, &srcLoopNest);
+ SmallVector<ForInst *, 4> srcLoopNest;
+ getLoopIVs(*srcAccess->opInst, &srcLoopNest);
unsigned srcLoopNestSize = srcLoopNest.size();
assert(srcLoopDepth <= srcLoopNestSize);
// Get loop nest surrounding dst operation.
- SmallVector<ForStmt *, 4> dstLoopNest;
- getLoopIVs(*dstAccess->opStmt, &dstLoopNest);
+ SmallVector<ForInst *, 4> dstLoopNest;
+ getLoopIVs(*dstAccess->opInst, &dstLoopNest);
unsigned dstLoopNestSize = dstLoopNest.size();
(void)dstLoopNestSize;
assert(dstLoopDepth > 0);
@@ -425,7 +425,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
}
SmallVector<unsigned, 2> nonZeroDimIds;
SmallVector<unsigned, 2> nonZeroSymbolIds;
- srcIvMaps[i] = cst->toAffineMapFromEq(0, 0, srcAccess->opStmt->getContext(),
+ srcIvMaps[i] = cst->toAffineMapFromEq(0, 0, srcAccess->opInst->getContext(),
&nonZeroDimIds, &nonZeroSymbolIds);
if (srcIvMaps[i] == AffineMap::Null()) {
continue;
@@ -446,23 +446,23 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
// with symbol identifiers in 'nonZeroSymbolIds'.
}
- // Find the stmt block positions of 'srcAccess->opStmt' within 'srcLoopNest'.
+ // Find the inst block positions of 'srcAccess->opInst' within 'srcLoopNest'.
SmallVector<unsigned, 4> positions;
- findStmtPosition(srcAccess->opStmt, srcLoopNest[0]->getBlock(), &positions);
+ findInstPosition(srcAccess->opInst, srcLoopNest[0]->getBlock(), &positions);
- // Clone src loop nest and insert it a the beginning of the statement block
+ // Clone src loop nest and insert it at the beginning of the instruction block
// of the loop at 'dstLoopDepth' in 'dstLoopNest'.
- auto *dstForStmt = dstLoopNest[dstLoopDepth - 1];
- FuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin());
+ auto *dstForInst = dstLoopNest[dstLoopDepth - 1];
+ FuncBuilder b(dstForInst->getBody(), dstForInst->getBody()->begin());
DenseMap<const Value *, Value *> operandMap;
- auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap));
-
- // Lookup stmt in cloned 'sliceLoopNest' at 'positions'.
- Statement *sliceStmt =
- getStmtAtPosition(positions, /*level=*/0, sliceLoopNest->getBody());
- // Get loop nest surrounding 'sliceStmt'.
- SmallVector<ForStmt *, 4> sliceSurroundingLoops;
- getLoopIVs(*sliceStmt, &sliceSurroundingLoops);
+ auto *sliceLoopNest = cast<ForInst>(b.clone(*srcLoopNest[0], operandMap));
+
+ // Lookup inst in cloned 'sliceLoopNest' at 'positions'.
+ Instruction *sliceInst =
+ getInstAtPosition(positions, /*level=*/0, sliceLoopNest->getBody());
+ // Get loop nest surrounding 'sliceInst'.
+ SmallVector<ForInst *, 4> sliceSurroundingLoops;
+ getLoopIVs(*sliceInst, &sliceSurroundingLoops);
unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
(void)sliceSurroundingLoopsSize;
@@ -470,18 +470,18 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
unsigned sliceLoopLimit = dstLoopDepth + srcLoopNestSize;
assert(sliceLoopLimit <= sliceSurroundingLoopsSize);
for (unsigned i = dstLoopDepth; i < sliceLoopLimit; ++i) {
- auto *forStmt = sliceSurroundingLoops[i];
+ auto *forInst = sliceSurroundingLoops[i];
unsigned index = i - dstLoopDepth;
AffineMap lbMap = srcIvMaps[index];
if (lbMap == AffineMap::Null())
continue;
- forStmt->setLowerBound(srcIvOperands[index], lbMap);
+ forInst->setLowerBound(srcIvOperands[index], lbMap);
    // Create upper bound map which is the lower bound map + 1.
assert(lbMap.getNumResults() == 1);
AffineExpr ubResultExpr = lbMap.getResult(0) + 1;
AffineMap ubMap = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
{ubResultExpr}, {});
- forStmt->setUpperBound(srcIvOperands[index], ubMap);
+ forInst->setUpperBound(srcIvOperands[index], ubMap);
}
return sliceLoopNest;
}
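
For intuition, the single-iteration bounds installed by the loop above print roughly as follows under this commit's syntax (operands and maps are illustrative):

    // ub = lb + 1 with step 1 pins each sliced IV to one iteration:
    //
    //   for %i1 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 1)(%i0) {
    //     ...
    //   }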
diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp
index cd9451cd5e9..e092b29a13b 100644
--- a/mlir/lib/Analysis/VectorAnalysis.cpp
+++ b/mlir/lib/Analysis/VectorAnalysis.cpp
@@ -19,7 +19,7 @@
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/Functional.h"
@@ -105,7 +105,7 @@ Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(VectorType superVectorType,
static AffineMap makePermutationMap(
MLIRContext *context,
llvm::iterator_range<OperationInst::operand_iterator> indices,
- const DenseMap<ForStmt *, unsigned> &enclosingLoopToVectorDim) {
+ const DenseMap<ForInst *, unsigned> &enclosingLoopToVectorDim) {
using functional::makePtrDynCaster;
using functional::map;
auto unwrappedIndices = map(makePtrDynCaster<Value, Value>(), indices);
@@ -137,10 +137,11 @@ static AffineMap makePermutationMap(
/// the specified type.
/// TODO(ntv): could also be implemented as a collect parents followed by a
/// filter and made available outside this file.
-template <typename T> static SetVector<T *> getParentsOfType(Statement *stmt) {
+template <typename T>
+static SetVector<T *> getParentsOfType(Instruction *inst) {
SetVector<T *> res;
- auto *current = stmt;
- while (auto *parent = current->getParentStmt()) {
+ auto *current = inst;
+ while (auto *parent = current->getParentInst()) {
auto *typedParent = dyn_cast<T>(parent);
if (typedParent) {
assert(res.count(typedParent) == 0 && "Already inserted");
@@ -151,34 +152,34 @@ template <typename T> static SetVector<T *> getParentsOfType(Statement *stmt) {
return res;
}
-/// Returns the enclosing ForStmt, from closest to farthest.
-static SetVector<ForStmt *> getEnclosingForStmts(Statement *stmt) {
- return getParentsOfType<ForStmt>(stmt);
+/// Returns the enclosing ForInsts, from closest to farthest.
+static SetVector<ForInst *> getEnclosingForInsts(Instruction *inst) {
+ return getParentsOfType<ForInst>(inst);
}
AffineMap
-mlir::makePermutationMap(OperationInst *opStmt,
- const DenseMap<ForStmt *, unsigned> &loopToVectorDim) {
- DenseMap<ForStmt *, unsigned> enclosingLoopToVectorDim;
- auto enclosingLoops = getEnclosingForStmts(opStmt);
- for (auto *forStmt : enclosingLoops) {
- auto it = loopToVectorDim.find(forStmt);
+mlir::makePermutationMap(OperationInst *opInst,
+ const DenseMap<ForInst *, unsigned> &loopToVectorDim) {
+ DenseMap<ForInst *, unsigned> enclosingLoopToVectorDim;
+ auto enclosingLoops = getEnclosingForInsts(opInst);
+ for (auto *forInst : enclosingLoops) {
+ auto it = loopToVectorDim.find(forInst);
if (it != loopToVectorDim.end()) {
enclosingLoopToVectorDim.insert(*it);
}
}
- if (auto load = opStmt->dyn_cast<LoadOp>()) {
- return ::makePermutationMap(opStmt->getContext(), load->getIndices(),
+ if (auto load = opInst->dyn_cast<LoadOp>()) {
+ return ::makePermutationMap(opInst->getContext(), load->getIndices(),
enclosingLoopToVectorDim);
}
- auto store = opStmt->cast<StoreOp>();
- return ::makePermutationMap(opStmt->getContext(), store->getIndices(),
+ auto store = opInst->cast<StoreOp>();
+ return ::makePermutationMap(opInst->getContext(), store->getIndices(),
enclosingLoopToVectorDim);
}
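
A hedged usage sketch of makePermutationMap; the loop handles and the resulting map are assumptions chosen to show the call pattern, not output from the source:

    // Two enclosing loops; the innermost is vectorized along vector dim 0.
    DenseMap<ForInst *, unsigned> loopToVectorDim;
    loopToVectorDim[innerLoop] = 0;
    loopToVectorDim[outerLoop] = 1;
    // For a load %A[%iOuter, %iInner] this could produce the permutation
    // map (d0, d1) -> (d1, d0), swapping memref dims into vector dims.
    AffineMap perm = makePermutationMap(loadInst, loopToVectorDim);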
-bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt,
+bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opInst,
VectorType subVectorType) {
  // First, extract the vector type and distinguish between:
// a. ops that *must* lower a super-vector (i.e. vector_transfer_read,
@@ -191,20 +192,20 @@ bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt,
/// do not have to special case. Maybe a trait, or just a method, unclear atm.
bool mustDivide = false;
VectorType superVectorType;
- if (auto read = opStmt.dyn_cast<VectorTransferReadOp>()) {
+ if (auto read = opInst.dyn_cast<VectorTransferReadOp>()) {
superVectorType = read->getResultType();
mustDivide = true;
- } else if (auto write = opStmt.dyn_cast<VectorTransferWriteOp>()) {
+ } else if (auto write = opInst.dyn_cast<VectorTransferWriteOp>()) {
superVectorType = write->getVectorType();
mustDivide = true;
- } else if (opStmt.getNumResults() == 0) {
- if (!opStmt.isa<ReturnOp>()) {
- opStmt.emitError("NYI: assuming only return statements can have 0 "
+ } else if (opInst.getNumResults() == 0) {
+ if (!opInst.isa<ReturnOp>()) {
+ opInst.emitError("NYI: assuming only return instructions can have 0 "
" results at this point");
}
return false;
- } else if (opStmt.getNumResults() == 1) {
- if (auto v = opStmt.getResult(0)->getType().dyn_cast<VectorType>()) {
+ } else if (opInst.getNumResults() == 1) {
+ if (auto v = opInst.getResult(0)->getType().dyn_cast<VectorType>()) {
superVectorType = v;
} else {
// Not a vector type.
@@ -213,7 +214,7 @@ bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt,
} else {
// Not a vector_transfer and has more than 1 result, fail hard for now to
// wake us up when something changes.
- opStmt.emitError("NYI: statement has more than 1 result");
+ opInst.emitError("NYI: instruction has more than 1 result");
return false;
}
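
The super-/sub-vector relationship checked here ultimately rests on shapeRatio (see the hunk header above); a hedged illustration of its divisibility rule:

    // Shapes align on trailing dimensions and must divide elementwise;
    // unmatched leading super-vector dims contribute their full extent.
    //   shapeRatio(vector<4x8x16xf32>, vector<2x16xf32>) -> {4, 4, 1}
    //   shapeRatio(vector<4x8x16xf32>, vector<3x16xf32>) -> None (8 % 3 != 0)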
diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp
index 4cad531ecaa..7217c5492a6 100644
--- a/mlir/lib/Analysis/Verifier.cpp
+++ b/mlir/lib/Analysis/Verifier.cpp
@@ -36,9 +36,9 @@
#include "mlir/Analysis/Dominance.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Function.h"
+#include "mlir/IR/InstVisitor.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/IR/Module.h"
-#include "mlir/IR/Statements.h"
-#include "mlir/IR/StmtVisitor.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/raw_ostream.h"
@@ -239,14 +239,14 @@ bool CFGFuncVerifier::verifyBlock(const Block &block) {
//===----------------------------------------------------------------------===//
namespace {
-struct MLFuncVerifier : public Verifier, public StmtWalker<MLFuncVerifier> {
+struct MLFuncVerifier : public Verifier, public InstWalker<MLFuncVerifier> {
const Function &fn;
bool hadError = false;
MLFuncVerifier(const Function &fn) : Verifier(fn), fn(fn) {}
- void visitOperationInst(OperationInst *opStmt) {
- hadError |= verifyOperation(*opStmt);
+ void visitOperationInst(OperationInst *opInst) {
+ hadError |= verifyOperation(*opInst);
}
bool verify() {
@@ -269,7 +269,7 @@ struct MLFuncVerifier : public Verifier, public StmtWalker<MLFuncVerifier> {
/// operations are properly dominated by their definitions.
bool verifyDominance();
- /// Verify that function has a return statement that matches its signature.
+ /// Verify that function has a return instruction that matches its signature.
bool verifyReturn();
};
} // end anonymous namespace
@@ -285,48 +285,48 @@ bool MLFuncVerifier::verifyDominance() {
for (auto *arg : fn.getArguments())
liveValues.insert(arg, true);
- // This recursive function walks the statement list pushing scopes onto the
+ // This recursive function walks the instruction list pushing scopes onto the
// stack as it goes, and popping them to remove them from the table.
std::function<bool(const Block &block)> walkBlock;
walkBlock = [&](const Block &block) -> bool {
HashTable::ScopeTy blockScope(liveValues);
- // The induction variable of a for statement is live within its body.
- if (auto *forStmt = dyn_cast_or_null<ForStmt>(block.getContainingInst()))
- liveValues.insert(forStmt, true);
+ // The induction variable of a for instruction is live within its body.
+ if (auto *forInst = dyn_cast_or_null<ForInst>(block.getContainingInst()))
+ liveValues.insert(forInst, true);
- for (auto &stmt : block) {
+ for (auto &inst : block) {
// Verify that each of the operands are live.
unsigned operandNo = 0;
- for (auto *opValue : stmt.getOperands()) {
+ for (auto *opValue : inst.getOperands()) {
if (!liveValues.count(opValue)) {
- stmt.emitError("operand #" + Twine(operandNo) +
+ inst.emitError("operand #" + Twine(operandNo) +
" does not dominate this use");
- if (auto *useStmt = opValue->getDefiningInst())
- useStmt->emitNote("operand defined here");
+ if (auto *useInst = opValue->getDefiningInst())
+ useInst->emitNote("operand defined here");
return true;
}
++operandNo;
}
- if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
+ if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
// Operations define values, add them to the hash table.
- for (auto *result : opStmt->getResults())
+ for (auto *result : opInst->getResults())
liveValues.insert(result, true);
continue;
}
    // If this is an if or for, recursively walk the blocks it contains.
- if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
- if (walkBlock(*ifStmt->getThen()))
+ if (auto *ifInst = dyn_cast<IfInst>(&inst)) {
+ if (walkBlock(*ifInst->getThen()))
return true;
- if (auto *elseClause = ifStmt->getElse())
+ if (auto *elseClause = ifInst->getElse())
if (walkBlock(*elseClause))
return true;
}
- if (auto *forStmt = dyn_cast<ForStmt>(&stmt))
- if (walkBlock(*forStmt->getBody()))
+ if (auto *forInst = dyn_cast<ForInst>(&inst))
+ if (walkBlock(*forInst->getBody()))
return true;
}
@@ -338,13 +338,14 @@ bool MLFuncVerifier::verifyDominance() {
}
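
The walk above leans on ScopedHashTable semantics; a minimal sketch of the scoping discipline ('someResult' is an assumed value):

    {
      HashTable::ScopeTy bodyScope(liveValues);
      liveValues.insert(someResult, true);
      // ... uses of someResult are dominated here ...
    } // Scope ends: someResult no longer counts as live in later blocks.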
bool MLFuncVerifier::verifyReturn() {
- // TODO: fold return verification in the pass that verifies all statements.
- const char missingReturnMsg[] = "ML function must end with return statement";
+ // TODO: fold return verification in the pass that verifies all instructions.
+ const char missingReturnMsg[] =
+ "ML function must end with return instruction";
if (fn.getBody()->getInstructions().empty())
return failure(missingReturnMsg, fn);
- const auto &stmt = fn.getBody()->getInstructions().back();
- if (const auto *op = dyn_cast<OperationInst>(&stmt)) {
+ const auto &inst = fn.getBody()->getInstructions().back();
+ if (const auto *op = dyn_cast<OperationInst>(&inst)) {
if (!op->isReturn())
return failure(missingReturnMsg, fn);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index daaaee7010c..cf822e025b8 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -25,11 +25,11 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Function.h"
+#include "mlir/IR/InstVisitor.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/Statements.h"
-#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APFloat.h"
@@ -117,10 +117,10 @@ private:
void visitExtFunction(const Function *fn);
void visitCFGFunction(const Function *fn);
void visitMLFunction(const Function *fn);
- void visitStatement(const Statement *stmt);
- void visitForStmt(const ForStmt *forStmt);
- void visitIfStmt(const IfStmt *ifStmt);
- void visitOperationInst(const OperationInst *opStmt);
+ void visitInstruction(const Instruction *inst);
+ void visitForInst(const ForInst *forInst);
+ void visitIfInst(const IfInst *ifInst);
+ void visitOperationInst(const OperationInst *opInst);
void visitType(Type type);
void visitAttribute(Attribute attr);
void visitOperation(const OperationInst *op);
@@ -184,47 +184,47 @@ void ModuleState::visitCFGFunction(const Function *fn) {
if (auto *opInst = dyn_cast<OperationInst>(&op))
visitOperation(opInst);
else {
- llvm_unreachable("IfStmt/ForStmt in a CFG Function isn't supported");
+ llvm_unreachable("IfInst/ForInst in a CFG Function isn't supported");
}
}
}
}
-void ModuleState::visitIfStmt(const IfStmt *ifStmt) {
- recordIntegerSetReference(ifStmt->getIntegerSet());
- for (auto &childStmt : *ifStmt->getThen())
- visitStatement(&childStmt);
- if (ifStmt->hasElse())
- for (auto &childStmt : *ifStmt->getElse())
- visitStatement(&childStmt);
+void ModuleState::visitIfInst(const IfInst *ifInst) {
+ recordIntegerSetReference(ifInst->getIntegerSet());
+ for (auto &childInst : *ifInst->getThen())
+ visitInstruction(&childInst);
+ if (ifInst->hasElse())
+ for (auto &childInst : *ifInst->getElse())
+ visitInstruction(&childInst);
}
-void ModuleState::visitForStmt(const ForStmt *forStmt) {
- AffineMap lbMap = forStmt->getLowerBoundMap();
+void ModuleState::visitForInst(const ForInst *forInst) {
+ AffineMap lbMap = forInst->getLowerBoundMap();
if (!hasShorthandForm(lbMap))
recordAffineMapReference(lbMap);
- AffineMap ubMap = forStmt->getUpperBoundMap();
+ AffineMap ubMap = forInst->getUpperBoundMap();
if (!hasShorthandForm(ubMap))
recordAffineMapReference(ubMap);
- for (auto &childStmt : *forStmt->getBody())
- visitStatement(&childStmt);
+ for (auto &childInst : *forInst->getBody())
+ visitInstruction(&childInst);
}
-void ModuleState::visitOperationInst(const OperationInst *opStmt) {
- for (auto attr : opStmt->getAttrs())
+void ModuleState::visitOperationInst(const OperationInst *opInst) {
+ for (auto attr : opInst->getAttrs())
visitAttribute(attr.second);
}
-void ModuleState::visitStatement(const Statement *stmt) {
- switch (stmt->getKind()) {
- case Statement::Kind::If:
- return visitIfStmt(cast<IfStmt>(stmt));
- case Statement::Kind::For:
- return visitForStmt(cast<ForStmt>(stmt));
- case Statement::Kind::OperationInst:
- return visitOperationInst(cast<OperationInst>(stmt));
+void ModuleState::visitInstruction(const Instruction *inst) {
+ switch (inst->getKind()) {
+ case Instruction::Kind::If:
+ return visitIfInst(cast<IfInst>(inst));
+ case Instruction::Kind::For:
+ return visitForInst(cast<ForInst>(inst));
+ case Instruction::Kind::OperationInst:
+ return visitOperationInst(cast<OperationInst>(inst));
default:
return;
}
@@ -232,8 +232,8 @@ void ModuleState::visitStatement(const Statement *stmt) {
void ModuleState::visitMLFunction(const Function *fn) {
visitType(fn->getType());
- for (auto &stmt : *fn->getBody()) {
- ModuleState::visitStatement(&stmt);
+ for (auto &inst : *fn->getBody()) {
+ ModuleState::visitInstruction(&inst);
}
}
@@ -909,11 +909,11 @@ public:
void printMLFunctionSignature();
void printOtherFunctionSignature();
- // Methods to print statements.
- void print(const Statement *stmt);
+ // Methods to print instructions.
+ void print(const Instruction *inst);
void print(const OperationInst *inst);
- void print(const ForStmt *stmt);
- void print(const IfStmt *stmt);
+ void print(const ForInst *inst);
+ void print(const IfInst *inst);
void print(const Block *block);
void printOperation(const OperationInst *op);
@@ -959,7 +959,7 @@ public:
void printDimAndSymbolList(ArrayRef<InstOperand> ops, unsigned numDims);
void printBound(AffineBound bound, const char *prefix);
- // Number of spaces used for indenting nested statements.
+ // Number of spaces used for indenting nested instructions.
const static unsigned indentWidth = 2;
protected:
@@ -1019,22 +1019,22 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) {
  // We number instructions that have results, and we only number the first
// result.
switch (inst.getKind()) {
- case Statement::Kind::OperationInst: {
+ case Instruction::Kind::OperationInst: {
auto *opInst = cast<OperationInst>(&inst);
if (opInst->getNumResults() != 0)
numberValueID(opInst->getResult(0));
break;
}
- case Statement::Kind::For: {
- auto *forInst = cast<ForStmt>(&inst);
+ case Instruction::Kind::For: {
+ auto *forInst = cast<ForInst>(&inst);
// Number the induction variable.
numberValueID(forInst);
// Recursively number the stuff in the body.
numberValuesInBlock(*forInst->getBody());
break;
}
- case Statement::Kind::If: {
- auto *ifInst = cast<IfStmt>(&inst);
+ case Instruction::Kind::If: {
+ auto *ifInst = cast<IfInst>(&inst);
numberValuesInBlock(*ifInst->getThen());
if (auto *elseBlock = ifInst->getElse())
numberValuesInBlock(*elseBlock);
@@ -1086,7 +1086,7 @@ void FunctionPrinter::numberValueID(const Value *value) {
// done with it.
valueIDs[value] = nextValueID++;
return;
- case Value::Kind::ForStmt:
+ case Value::Kind::ForInst:
specialName << 'i' << nextLoopID++;
break;
}
@@ -1220,21 +1220,21 @@ void FunctionPrinter::print(const Block *block) {
currentIndent += indentWidth;
- for (auto &stmt : block->getInstructions()) {
- print(&stmt);
+ for (auto &inst : block->getInstructions()) {
+ print(&inst);
os << '\n';
}
currentIndent -= indentWidth;
}
-void FunctionPrinter::print(const Statement *stmt) {
- switch (stmt->getKind()) {
- case Statement::Kind::OperationInst:
- return print(cast<OperationInst>(stmt));
- case Statement::Kind::For:
- return print(cast<ForStmt>(stmt));
- case Statement::Kind::If:
- return print(cast<IfStmt>(stmt));
+void FunctionPrinter::print(const Instruction *inst) {
+ switch (inst->getKind()) {
+ case Instruction::Kind::OperationInst:
+ return print(cast<OperationInst>(inst));
+ case Instruction::Kind::For:
+ return print(cast<ForInst>(inst));
+ case Instruction::Kind::If:
+ return print(cast<IfInst>(inst));
}
}
@@ -1243,33 +1243,33 @@ void FunctionPrinter::print(const OperationInst *inst) {
printOperation(inst);
}
-void FunctionPrinter::print(const ForStmt *stmt) {
+void FunctionPrinter::print(const ForInst *inst) {
os.indent(currentIndent) << "for ";
- printOperand(stmt);
+ printOperand(inst);
os << " = ";
- printBound(stmt->getLowerBound(), "max");
+ printBound(inst->getLowerBound(), "max");
os << " to ";
- printBound(stmt->getUpperBound(), "min");
+ printBound(inst->getUpperBound(), "min");
- if (stmt->getStep() != 1)
- os << " step " << stmt->getStep();
+ if (inst->getStep() != 1)
+ os << " step " << inst->getStep();
os << " {\n";
- print(stmt->getBody());
+ print(inst->getBody());
os.indent(currentIndent) << "}";
}
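
For reference, a hedged sketch of what this method emits (operand and map names are illustrative):

    //   for %i0 = 0 to 1024 step 4 {
    //     ...
    //   }
    //   for %i1 = max #map0(%arg0) to min #map1(%arg0) {
    //     ...
    //   }
    // The 'max'/'min' keywords come from printBound and should only appear
    // for multi-result bound maps; 'step' is elided when it equals 1.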
-void FunctionPrinter::print(const IfStmt *stmt) {
+void FunctionPrinter::print(const IfInst *inst) {
os.indent(currentIndent) << "if ";
- IntegerSet set = stmt->getIntegerSet();
+ IntegerSet set = inst->getIntegerSet();
printIntegerSetReference(set);
- printDimAndSymbolList(stmt->getInstOperands(), set.getNumDims());
+ printDimAndSymbolList(inst->getInstOperands(), set.getNumDims());
os << " {\n";
- print(stmt->getThen());
+ print(inst->getThen());
os.indent(currentIndent) << "}";
- if (stmt->hasElse()) {
+ if (inst->hasElse()) {
os << " else {\n";
- print(stmt->getElse());
+ print(inst->getElse());
os.indent(currentIndent) << "}";
}
}
@@ -1280,7 +1280,7 @@ void FunctionPrinter::printValueID(const Value *value,
auto lookupValue = value;
// If this is a reference to the result of a multi-result instruction or
- // statement, print out the # identifier and make sure to map our lookup
+ // operation, print out the # identifier and make sure to map our lookup
// to the first result of the instruction.
if (auto *result = dyn_cast<InstResult>(value)) {
if (result->getOwner()->getNumResults() != 1) {
@@ -1493,8 +1493,8 @@ void Value::print(raw_ostream &os) const {
return;
case Value::Kind::InstResult:
return getDefiningInst()->print(os);
- case Value::Kind::ForStmt:
- return cast<ForStmt>(this)->print(os);
+ case Value::Kind::ForInst:
+ return cast<ForInst>(this)->print(os);
}
}
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index c7e84194c35..2efba2bbf69 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -26,16 +26,16 @@ Block::~Block() {
llvm::DeleteContainerPointers(arguments);
}
-/// Returns the closest surrounding statement that contains this block or
-/// nullptr if this is a top-level statement block.
-Statement *Block::getContainingInst() {
+/// Returns the closest surrounding instruction that contains this block or
+/// nullptr if this is a top-level instruction block.
+Instruction *Block::getContainingInst() {
return parent ? parent->getContainingInst() : nullptr;
}
Function *Block::getFunction() {
Block *block = this;
- while (auto *stmt = block->getContainingInst()) {
- block = stmt->getBlock();
+ while (auto *inst = block->getContainingInst()) {
+ block = inst->getBlock();
if (!block)
return nullptr;
}
@@ -49,11 +49,11 @@ Function *Block::getFunction() {
/// the latter fails.
const Instruction *
Block::findAncestorInstInBlock(const Instruction &inst) const {
- // Traverse up the statement hierarchy starting from the owner of operand to
- // find the ancestor statement that resides in the block of 'forStmt'.
+ // Traverse up the instruction hierarchy starting from 'inst' to find the
+ // ancestor instruction that resides in this block.
const auto *currInst = &inst;
while (currInst->getBlock() != this) {
- currInst = currInst->getParentStmt();
+ currInst = currInst->getParentInst();
if (!currInst)
return nullptr;
}
@@ -106,10 +106,10 @@ OperationInst *Block::getTerminator() {
// Check if the last instruction is a terminator.
auto &backInst = back();
- auto *opStmt = dyn_cast<OperationInst>(&backInst);
- if (!opStmt || !opStmt->isTerminator())
+ auto *opInst = dyn_cast<OperationInst>(&backInst);
+ if (!opInst || !opInst->isTerminator())
return nullptr;
- return opStmt;
+ return opInst;
}
/// Return true if this block has no predecessors.
@@ -184,10 +184,10 @@ Block *Block::splitBlock(iterator splitBefore) {
BlockList::BlockList(Function *container) : container(container) {}
-BlockList::BlockList(Statement *container) : container(container) {}
+BlockList::BlockList(Instruction *container) : container(container) {}
-Statement *BlockList::getContainingInst() {
- return container.dyn_cast<Statement *>();
+Instruction *BlockList::getContainingInst() {
+ return container.dyn_cast<Instruction *>();
}
Function *BlockList::getContainingFunction() {
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index a9eb6fe8c8a..4c7c8ddae81 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -268,7 +268,7 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
}
//===----------------------------------------------------------------------===//
-// Statements.
+// Instructions.
//===----------------------------------------------------------------------===//
/// Add new basic block and set the insertion point to the end of it. If an
@@ -298,25 +298,25 @@ OperationInst *FuncBuilder::createOperation(const OperationState &state) {
return op;
}
-ForStmt *FuncBuilder::createFor(Location location, ArrayRef<Value *> lbOperands,
+ForInst *FuncBuilder::createFor(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step) {
- auto *stmt =
- ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
- block->getInstructions().insert(insertPoint, stmt);
- return stmt;
+ auto *inst =
+ ForInst::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
+ block->getInstructions().insert(insertPoint, inst);
+ return inst;
}
-ForStmt *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
+ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
int64_t step) {
auto lbMap = AffineMap::getConstantMap(lb, context);
auto ubMap = AffineMap::getConstantMap(ub, context);
return createFor(location, {}, lbMap, {}, ubMap, step);
}
-IfStmt *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
+IfInst *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
IntegerSet set) {
- auto *stmt = IfStmt::create(location, operands, set);
- block->getInstructions().insert(insertPoint, stmt);
- return stmt;
+ auto *inst = IfInst::create(location, operands, set);
+ block->getInstructions().insert(insertPoint, inst);
+ return inst;
}
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
index cbe84e10247..bacb504683b 100644
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -18,9 +18,9 @@
#include "mlir/IR/Function.h"
#include "AttributeListStorage.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
-#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
@@ -161,21 +161,21 @@ bool Function::emitError(const Twine &message) const {
// Function implementation.
//===----------------------------------------------------------------------===//
-const OperationInst *Function::getReturnStmt() const {
+const OperationInst *Function::getReturn() const {
return cast<OperationInst>(&getBody()->back());
}
-OperationInst *Function::getReturnStmt() {
+OperationInst *Function::getReturn() {
return cast<OperationInst>(&getBody()->back());
}
void Function::walk(std::function<void(OperationInst *)> callback) {
- struct Walker : public StmtWalker<Walker> {
+ struct Walker : public InstWalker<Walker> {
std::function<void(OperationInst *)> const &callback;
Walker(std::function<void(OperationInst *)> const &callback)
: callback(callback) {}
- void visitOperationInst(OperationInst *opStmt) { callback(opStmt); }
+ void visitOperationInst(OperationInst *opInst) { callback(opInst); }
};
Walker v(callback);
@@ -183,12 +183,12 @@ void Function::walk(std::function<void(OperationInst *)> callback) {
}
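
A hedged usage sketch of this walker ('fn' is an assumed Function pointer; LoadOp comes from the standard ops dialect):

    // Count the load operations in a function.
    unsigned numLoads = 0;
    fn->walk([&](OperationInst *opInst) {
      if (opInst->isa<LoadOp>())
        ++numLoads;
    });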
void Function::walkPostOrder(std::function<void(OperationInst *)> callback) {
- struct Walker : public StmtWalker<Walker> {
+ struct Walker : public InstWalker<Walker> {
std::function<void(OperationInst *)> const &callback;
Walker(std::function<void(OperationInst *)> const &callback)
: callback(callback) {}
- void visitOperationInst(OperationInst *opStmt) { callback(opStmt); }
+ void visitOperationInst(OperationInst *opInst) { callback(opInst); }
};
Walker v(callback);
diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Instruction.cpp
index 6bd9944bb65..92f3c4ecba3 100644
--- a/mlir/lib/IR/Statement.cpp
+++ b/mlir/lib/IR/Instruction.cpp
@@ -1,4 +1,4 @@
-//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
+//===- Instruction.cpp - MLIR Instruction Classes ------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
@@ -20,10 +21,10 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Function.h"
+#include "mlir/IR/InstVisitor.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Statements.h"
-#include "mlir/IR/StmtVisitor.h"
#include "llvm/ADT/DenseMap.h"
using namespace mlir;
@@ -54,41 +55,43 @@ template <> unsigned BlockOperand::getOperandNumber() const {
}
//===----------------------------------------------------------------------===//
-// Statement
+// Instruction
//===----------------------------------------------------------------------===//
-// Statements are deleted through the destroy() member because we don't have
+// Instructions are deleted through the destroy() member because we don't have
// a virtual destructor.
-Statement::~Statement() {
- assert(block == nullptr && "statement destroyed but still in a block");
+Instruction::~Instruction() {
+ assert(block == nullptr && "instruction destroyed but still in a block");
}
-/// Destroy this statement or one of its subclasses.
-void Statement::destroy() {
+/// Destroy this instruction or one of its subclasses.
+void Instruction::destroy() {
switch (this->getKind()) {
case Kind::OperationInst:
cast<OperationInst>(this)->destroy();
break;
case Kind::For:
- delete cast<ForStmt>(this);
+ delete cast<ForInst>(this);
break;
case Kind::If:
- delete cast<IfStmt>(this);
+ delete cast<IfInst>(this);
break;
}
}
-Statement *Statement::getParentStmt() const {
+Instruction *Instruction::getParentInst() const {
return block ? block->getContainingInst() : nullptr;
}
-Function *Statement::getFunction() const {
+Function *Instruction::getFunction() const {
return block ? block->getFunction() : nullptr;
}
-Value *Statement::getOperand(unsigned idx) { return getInstOperand(idx).get(); }
+Value *Instruction::getOperand(unsigned idx) {
+ return getInstOperand(idx).get();
+}
-const Value *Statement::getOperand(unsigned idx) const {
+const Value *Instruction::getOperand(unsigned idx) const {
return getInstOperand(idx).get();
}
@@ -96,12 +99,12 @@ const Value *Statement::getOperand(unsigned idx) const {
// it is an induction variable, or it is a result of affine apply operation
// with dimension id arguments.
bool Value::isValidDim() const {
- if (auto *stmt = getDefiningInst()) {
- // Top level statement or constant operation is ok.
- if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>())
+ if (auto *inst = getDefiningInst()) {
+ // Top level instruction or constant operation is ok.
+ if (inst->getParentInst() == nullptr || inst->isa<ConstantOp>())
return true;
// Affine apply operation is ok if all of its operands are ok.
- if (auto op = stmt->dyn_cast<AffineApplyOp>())
+ if (auto op = inst->dyn_cast<AffineApplyOp>())
return op->isValidDim();
return false;
}
@@ -114,12 +117,12 @@ bool Value::isValidDim() const {
// the top level, or it is a result of affine apply operation with symbol
// arguments.
bool Value::isValidSymbol() const {
- if (auto *stmt = getDefiningInst()) {
- // Top level statement or constant operation is ok.
- if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>())
+ if (auto *inst = getDefiningInst()) {
+ // Top level instruction or constant operation is ok.
+ if (inst->getParentInst() == nullptr || inst->isa<ConstantOp>())
return true;
// Affine apply operation is ok if all of its operands are ok.
- if (auto op = stmt->dyn_cast<AffineApplyOp>())
+ if (auto op = inst->dyn_cast<AffineApplyOp>())
return op->isValidSymbol();
return false;
}
@@ -128,42 +131,42 @@ bool Value::isValidSymbol() const {
return isa<BlockArgument>(this);
}
-void Statement::setOperand(unsigned idx, Value *value) {
+void Instruction::setOperand(unsigned idx, Value *value) {
getInstOperand(idx).set(value);
}
-unsigned Statement::getNumOperands() const {
+unsigned Instruction::getNumOperands() const {
switch (getKind()) {
case Kind::OperationInst:
return cast<OperationInst>(this)->getNumOperands();
case Kind::For:
- return cast<ForStmt>(this)->getNumOperands();
+ return cast<ForInst>(this)->getNumOperands();
case Kind::If:
- return cast<IfStmt>(this)->getNumOperands();
+ return cast<IfInst>(this)->getNumOperands();
}
}
-MutableArrayRef<InstOperand> Statement::getInstOperands() {
+MutableArrayRef<InstOperand> Instruction::getInstOperands() {
switch (getKind()) {
case Kind::OperationInst:
return cast<OperationInst>(this)->getInstOperands();
case Kind::For:
- return cast<ForStmt>(this)->getInstOperands();
+ return cast<ForInst>(this)->getInstOperands();
case Kind::If:
- return cast<IfStmt>(this)->getInstOperands();
+ return cast<IfInst>(this)->getInstOperands();
}
}
-/// Emit a note about this statement, reporting up to any diagnostic
+/// Emit a note about this instruction, reporting up to any diagnostic
/// handlers that may be listening.
-void Statement::emitNote(const Twine &message) const {
+void Instruction::emitNote(const Twine &message) const {
getContext()->emitDiagnostic(getLoc(), message,
MLIRContext::DiagnosticKind::Note);
}
-/// Emit a warning about this statement, reporting up to any diagnostic
+/// Emit a warning about this instruction, reporting up to any diagnostic
/// handlers that may be listening.
-void Statement::emitWarning(const Twine &message) const {
+void Instruction::emitWarning(const Twine &message) const {
getContext()->emitDiagnostic(getLoc(), message,
MLIRContext::DiagnosticKind::Warning);
}
@@ -172,80 +175,80 @@ void Statement::emitWarning(const Twine &message) const {
/// any diagnostic handlers that may be listening. This function always
/// returns true. NOTE: This may terminate the containing application; only
/// use when the IR is in an inconsistent state.
-bool Statement::emitError(const Twine &message) const {
+bool Instruction::emitError(const Twine &message) const {
return getContext()->emitError(getLoc(), message);
}
-// Returns whether the Statement is a terminator.
-bool Statement::isTerminator() const {
+// Returns whether the Instruction is a terminator.
+bool Instruction::isTerminator() const {
if (auto *op = dyn_cast<OperationInst>(this))
return op->isTerminator();
return false;
}
//===----------------------------------------------------------------------===//
-// ilist_traits for Statement
+// ilist_traits for Instruction
//===----------------------------------------------------------------------===//
-void llvm::ilist_traits<::mlir::Statement>::deleteNode(Statement *stmt) {
- stmt->destroy();
+void llvm::ilist_traits<::mlir::Instruction>::deleteNode(Instruction *inst) {
+ inst->destroy();
}
-Block *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
+Block *llvm::ilist_traits<::mlir::Instruction>::getContainingBlock() {
size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
- iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
+ iplist<Instruction> *Anchor(static_cast<iplist<Instruction> *>(this));
return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset);
}
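
The pointer arithmetic above is the usual LLVM ilist anchor trick, restated in hedged pseudo-code:

    //   offset = byte offset of the instruction list member inside Block
    //   block  = (Block *)((char *)this_list - offset)
    // Block::getSublistAccess supplies the member pointer so the offset is
    // computed without naming Block's private list field here.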
-/// This is a trait method invoked when a statement is added to a block. We
+/// This is a trait method invoked when an instruction is added to a block. We
/// keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
- assert(!stmt->getBlock() && "already in a statement block!");
- stmt->block = getContainingBlock();
+void llvm::ilist_traits<::mlir::Instruction>::addNodeToList(Instruction *inst) {
+ assert(!inst->getBlock() && "already in an instruction block!");
+ inst->block = getContainingBlock();
}
-/// This is a trait method invoked when a statement is removed from a block.
+/// This is a trait method invoked when an instruction is removed from a block.
/// We keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
- Statement *stmt) {
- assert(stmt->block && "not already in a statement block!");
- stmt->block = nullptr;
+void llvm::ilist_traits<::mlir::Instruction>::removeNodeFromList(
+ Instruction *inst) {
+ assert(inst->block && "not already in an instruction block!");
+ inst->block = nullptr;
}
-/// This is a trait method invoked when a statement is moved from one block
+/// This is a trait method invoked when an instruction is moved from one block
/// to another. We keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
- ilist_traits<Statement> &otherList, stmt_iterator first,
- stmt_iterator last) {
- // If we are transferring statements within the same block, the block
+void llvm::ilist_traits<::mlir::Instruction>::transferNodesFromList(
+ ilist_traits<Instruction> &otherList, inst_iterator first,
+ inst_iterator last) {
+ // If we are transferring instructions within the same block, the block
// pointer doesn't need to be updated.
Block *curParent = getContainingBlock();
if (curParent == otherList.getContainingBlock())
return;
- // Update the 'block' member of each statement.
+ // Update the 'block' member of each instruction.
for (; first != last; ++first)
first->block = curParent;
}
-/// Remove this statement (and its descendants) from its Block and delete
+/// Remove this instruction (and its descendants) from its Block and delete
/// all of them.
-void Statement::erase() {
- assert(getBlock() && "Statement has no block");
+void Instruction::erase() {
+ assert(getBlock() && "Instruction has no block");
getBlock()->getInstructions().erase(this);
}
-/// Unlink this statement from its current block and insert it right before
-/// `existingStmt` which may be in the same or another block in the same
+/// Unlink this instruction from its current block and insert it right before
+/// `existingInst` which may be in the same or another block in the same
/// function.
-void Statement::moveBefore(Statement *existingStmt) {
- moveBefore(existingStmt->getBlock(), existingStmt->getIterator());
+void Instruction::moveBefore(Instruction *existingInst) {
+ moveBefore(existingInst->getBlock(), existingInst->getIterator());
}
/// Unlink this instruction from its current block and insert it right before
/// `iterator` in the specified block.
-void Statement::moveBefore(Block *block,
- llvm::iplist<Statement>::iterator iterator) {
+void Instruction::moveBefore(Block *block,
+ llvm::iplist<Instruction>::iterator iterator) {
block->getInstructions().splice(iterator, getBlock()->getInstructions(),
getIterator());
}
@@ -253,7 +256,7 @@ void Statement::moveBefore(Block *block,
/// This drops all operand uses from this instruction, which is an essential
/// step in breaking cyclic dependences between references when they are to
/// be deleted.
-void Statement::dropAllReferences() {
+void Instruction::dropAllReferences() {
for (auto &op : getInstOperands())
op.drop();
@@ -284,17 +287,17 @@ OperationInst *OperationInst::create(Location location, OperationName name,
resultTypes.size(), numSuccessors, numSuccessors, numOperands);
void *rawMem = malloc(byteSize);
- // Initialize the OperationInst part of the statement.
- auto stmt = ::new (rawMem)
+ // Initialize the OperationInst part of the instruction.
+ auto inst = ::new (rawMem)
OperationInst(location, name, numOperands, resultTypes.size(),
numSuccessors, attributes, context);
// Initialize the results and operands.
- auto instResults = stmt->getInstResults();
+ auto instResults = inst->getInstResults();
for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
- new (&instResults[i]) InstResult(resultTypes[i], stmt);
+ new (&instResults[i]) InstResult(resultTypes[i], inst);
- auto InstOperands = stmt->getInstOperands();
+ auto InstOperands = inst->getInstOperands();
// Initialize normal operands.
unsigned operandIt = 0, operandE = operands.size();
@@ -305,7 +308,7 @@ OperationInst *OperationInst::create(Location location, OperationName name,
// separately below.
if (!operands[operandIt])
break;
- new (&InstOperands[nextOperand++]) InstOperand(stmt, operands[operandIt]);
+ new (&InstOperands[nextOperand++]) InstOperand(inst, operands[operandIt]);
}
unsigned currentSuccNum = 0;
@@ -313,13 +316,13 @@ OperationInst *OperationInst::create(Location location, OperationName name,
  // Verify that the number of sentinel operands is equivalent to the number
// of successors.
assert(currentSuccNum == numSuccessors);
- return stmt;
+ return inst;
}
- assert(stmt->isTerminator() &&
+ assert(inst->isTerminator() &&
"Sentinal operand found in non terminator operand list.");
- auto instBlockOperands = stmt->getBlockOperands();
- unsigned *succOperandCountIt = stmt->getTrailingObjects<unsigned>();
+ auto instBlockOperands = inst->getBlockOperands();
+ unsigned *succOperandCountIt = inst->getTrailingObjects<unsigned>();
unsigned *succOperandCountE = succOperandCountIt + numSuccessors;
(void)succOperandCountE;
@@ -338,12 +341,12 @@ OperationInst *OperationInst::create(Location location, OperationName name,
}
new (&instBlockOperands[currentSuccNum])
- BlockOperand(stmt, successors[currentSuccNum]);
+ BlockOperand(inst, successors[currentSuccNum]);
*succOperandCountIt = 0;
++currentSuccNum;
continue;
}
- new (&InstOperands[nextOperand++]) InstOperand(stmt, operands[operandIt]);
+ new (&InstOperands[nextOperand++]) InstOperand(inst, operands[operandIt]);
++(*succOperandCountIt);
}
@@ -351,7 +354,7 @@ OperationInst *OperationInst::create(Location location, OperationName name,
// successors.
assert(currentSuccNum == numSuccessors);
- return stmt;
+ return inst;
}
OperationInst::OperationInst(Location location, OperationName name,
@@ -359,7 +362,7 @@ OperationInst::OperationInst(Location location, OperationName name,
unsigned numSuccessors,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context)
- : Statement(Kind::OperationInst, location), numOperands(numOperands),
+ : Instruction(Kind::OperationInst, location), numOperands(numOperands),
numResults(numResults), numSuccs(numSuccessors), name(name) {
#ifndef NDEBUG
for (auto elt : attributes)
@@ -524,10 +527,10 @@ bool OperationInst::emitOpError(const Twine &message) const {
}
//===----------------------------------------------------------------------===//
-// ForStmt
+// ForInst
//===----------------------------------------------------------------------===//
-ForStmt *ForStmt::create(Location location, ArrayRef<Value *> lbOperands,
+ForInst *ForInst::create(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step) {
assert(lbOperands.size() == lbMap.getNumInputs() &&
@@ -537,39 +540,39 @@ ForStmt *ForStmt::create(Location location, ArrayRef<Value *> lbOperands,
assert(step > 0 && "step has to be a positive integer constant");
unsigned numOperands = lbOperands.size() + ubOperands.size();
- ForStmt *stmt = new ForStmt(location, numOperands, lbMap, ubMap, step);
+ ForInst *inst = new ForInst(location, numOperands, lbMap, ubMap, step);
unsigned i = 0;
for (unsigned e = lbOperands.size(); i != e; ++i)
- stmt->operands.emplace_back(InstOperand(stmt, lbOperands[i]));
+ inst->operands.emplace_back(InstOperand(inst, lbOperands[i]));
for (unsigned j = 0, e = ubOperands.size(); j != e; ++i, ++j)
- stmt->operands.emplace_back(InstOperand(stmt, ubOperands[j]));
+ inst->operands.emplace_back(InstOperand(inst, ubOperands[j]));
- return stmt;
+ return inst;
}
-ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap,
+ForInst::ForInst(Location location, unsigned numOperands, AffineMap lbMap,
AffineMap ubMap, int64_t step)
- : Statement(Statement::Kind::For, location),
- Value(Value::Kind::ForStmt,
+ : Instruction(Instruction::Kind::For, location),
+ Value(Value::Kind::ForInst,
Type::getIndex(lbMap.getResult(0).getContext())),
body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
- // The body of a for stmt always has one block.
+ // The body of a 'for' instruction always has one block.
body.push_back(new Block());
operands.reserve(numOperands);
}
-const AffineBound ForStmt::getLowerBound() const {
+const AffineBound ForInst::getLowerBound() const {
return AffineBound(*this, 0, lbMap.getNumInputs(), lbMap);
}
-const AffineBound ForStmt::getUpperBound() const {
+const AffineBound ForInst::getUpperBound() const {
return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap);
}
-void ForStmt::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
+void ForInst::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
assert(lbOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
@@ -586,7 +589,7 @@ void ForStmt::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
this->lbMap = map;
}
-void ForStmt::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
+void ForInst::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
assert(ubOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
@@ -603,57 +606,57 @@ void ForStmt::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
this->ubMap = map;
}
-void ForStmt::setLowerBoundMap(AffineMap map) {
+void ForInst::setLowerBoundMap(AffineMap map) {
assert(lbMap.getNumDims() == map.getNumDims() &&
lbMap.getNumSymbols() == map.getNumSymbols());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
this->lbMap = map;
}
-void ForStmt::setUpperBoundMap(AffineMap map) {
+void ForInst::setUpperBoundMap(AffineMap map) {
assert(ubMap.getNumDims() == map.getNumDims() &&
ubMap.getNumSymbols() == map.getNumSymbols());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
this->ubMap = map;
}
-bool ForStmt::hasConstantLowerBound() const { return lbMap.isSingleConstant(); }
+bool ForInst::hasConstantLowerBound() const { return lbMap.isSingleConstant(); }
-bool ForStmt::hasConstantUpperBound() const { return ubMap.isSingleConstant(); }
+bool ForInst::hasConstantUpperBound() const { return ubMap.isSingleConstant(); }
-int64_t ForStmt::getConstantLowerBound() const {
+int64_t ForInst::getConstantLowerBound() const {
return lbMap.getSingleConstantResult();
}
-int64_t ForStmt::getConstantUpperBound() const {
+int64_t ForInst::getConstantUpperBound() const {
return ubMap.getSingleConstantResult();
}
-void ForStmt::setConstantLowerBound(int64_t value) {
+void ForInst::setConstantLowerBound(int64_t value) {
setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
}
-void ForStmt::setConstantUpperBound(int64_t value) {
+void ForInst::setConstantUpperBound(int64_t value) {
setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}
-ForStmt::operand_range ForStmt::getLowerBoundOperands() {
+ForInst::operand_range ForInst::getLowerBoundOperands() {
return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}
-ForStmt::const_operand_range ForStmt::getLowerBoundOperands() const {
+ForInst::const_operand_range ForInst::getLowerBoundOperands() const {
return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}
-ForStmt::operand_range ForStmt::getUpperBoundOperands() {
+ForInst::operand_range ForInst::getUpperBoundOperands() {
return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
}
-ForStmt::const_operand_range ForStmt::getUpperBoundOperands() const {
+ForInst::const_operand_range ForInst::getUpperBoundOperands() const {
return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
}
-bool ForStmt::matchingBoundOperandList() const {
+bool ForInst::matchingBoundOperandList() const {
if (lbMap.getNumDims() != ubMap.getNumDims() ||
lbMap.getNumSymbols() != ubMap.getNumSymbols())
return false;
@@ -668,46 +671,46 @@ bool ForStmt::matchingBoundOperandList() const {
}
//===----------------------------------------------------------------------===//
-// IfStmt
+// IfInst
//===----------------------------------------------------------------------===//
-IfStmt::IfStmt(Location location, unsigned numOperands, IntegerSet set)
- : Statement(Kind::If, location), thenClause(this), elseClause(nullptr),
+IfInst::IfInst(Location location, unsigned numOperands, IntegerSet set)
+ : Instruction(Kind::If, location), thenClause(this), elseClause(nullptr),
set(set) {
operands.reserve(numOperands);
- // The then of an 'if' stmt always has one block.
+ // The 'then' clause of an 'if' instruction always has one block.
thenClause.push_back(new Block());
}
-IfStmt::~IfStmt() {
+IfInst::~IfInst() {
if (elseClause)
delete elseClause;
- // An IfStmt's IntegerSet 'set' should not be deleted since it is
+ // An IfInst's IntegerSet 'set' should not be deleted since it is
// allocated through MLIRContext's bump pointer allocator.
}
-IfStmt *IfStmt::create(Location location, ArrayRef<Value *> operands,
+IfInst *IfInst::create(Location location, ArrayRef<Value *> operands,
IntegerSet set) {
unsigned numOperands = operands.size();
assert(numOperands == set.getNumOperands() &&
"operand cound does not match the integer set operand count");
- IfStmt *stmt = new IfStmt(location, numOperands, set);
+ IfInst *inst = new IfInst(location, numOperands, set);
for (auto *op : operands)
- stmt->operands.emplace_back(InstOperand(stmt, op));
+ inst->operands.emplace_back(InstOperand(inst, op));
- return stmt;
+ return inst;
}
-const AffineCondition IfStmt::getCondition() const {
+const AffineCondition IfInst::getCondition() const {
return AffineCondition(*this, set);
}
-MLIRContext *IfStmt::getContext() const {
- // Check for degenerate case of if statement with no operands.
+MLIRContext *IfInst::getContext() const {
+ // Check for the degenerate case of an 'if' instruction with no operands.
// This is unlikely, but legal.
if (operands.empty())
return getFunction()->getContext();
@@ -716,16 +719,16 @@ MLIRContext *IfStmt::getContext() const {
}
//===----------------------------------------------------------------------===//
-// Statement Cloning
+// Instruction Cloning
//===----------------------------------------------------------------------===//
-/// Create a deep copy of this statement, remapping any operands that use
-/// values outside of the statement using the map that is provided (leaving
+/// Create a deep copy of this instruction, remapping any operands that use
+/// values outside of the instruction using the map that is provided (leaving
/// them alone if no entry is present). Replaces references to cloned
-/// sub-statements to the corresponding statement that is copied, and adds
+/// sub-instructions with the corresponding copied instruction, and adds
/// those mappings to the map.
-Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap,
- MLIRContext *context) const {
+Instruction *Instruction::clone(DenseMap<const Value *, Value *> &operandMap,
+ MLIRContext *context) const {
// If the specified value is in operandMap, return the remapped value.
// Otherwise return the value itself.
auto remapOperand = [&](const Value *value) -> Value * {
@@ -735,48 +738,48 @@ Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap,
SmallVector<Value *, 8> operands;
SmallVector<Block *, 2> successors;
- if (auto *opStmt = dyn_cast<OperationInst>(this)) {
- operands.reserve(getNumOperands() + opStmt->getNumSuccessors());
+ if (auto *opInst = dyn_cast<OperationInst>(this)) {
+ operands.reserve(getNumOperands() + opInst->getNumSuccessors());
- if (!opStmt->isTerminator()) {
+ if (!opInst->isTerminator()) {
// Non-terminators just add all the operands.
for (auto *opValue : getOperands())
operands.push_back(remapOperand(opValue));
} else {
// We add the operands separated by nullptr's for each successor.
- unsigned firstSuccOperand = opStmt->getNumSuccessors()
- ? opStmt->getSuccessorOperandIndex(0)
- : opStmt->getNumOperands();
- auto InstOperands = opStmt->getInstOperands();
+ unsigned firstSuccOperand = opInst->getNumSuccessors()
+ ? opInst->getSuccessorOperandIndex(0)
+ : opInst->getNumOperands();
+ auto InstOperands = opInst->getInstOperands();
unsigned i = 0;
for (; i != firstSuccOperand; ++i)
operands.push_back(remapOperand(InstOperands[i].get()));
- successors.reserve(opStmt->getNumSuccessors());
- for (unsigned succ = 0, e = opStmt->getNumSuccessors(); succ != e;
+ successors.reserve(opInst->getNumSuccessors());
+ for (unsigned succ = 0, e = opInst->getNumSuccessors(); succ != e;
++succ) {
- successors.push_back(const_cast<Block *>(opStmt->getSuccessor(succ)));
+ successors.push_back(const_cast<Block *>(opInst->getSuccessor(succ)));
// Add sentinel to delineate successor operands.
operands.push_back(nullptr);
      // Remap the successor's operands.
- for (auto *operand : opStmt->getSuccessorOperands(succ))
+ for (auto *operand : opInst->getSuccessorOperands(succ))
operands.push_back(remapOperand(operand));
}
}
SmallVector<Type, 8> resultTypes;
- resultTypes.reserve(opStmt->getNumResults());
- for (auto *result : opStmt->getResults())
+ resultTypes.reserve(opInst->getNumResults());
+ for (auto *result : opInst->getResults())
resultTypes.push_back(result->getType());
- auto *newOp = OperationInst::create(getLoc(), opStmt->getName(), operands,
- resultTypes, opStmt->getAttrs(),
+ auto *newOp = OperationInst::create(getLoc(), opInst->getName(), operands,
+ resultTypes, opInst->getAttrs(),
successors, context);
// Remember the mapping of any results.
- for (unsigned i = 0, e = opStmt->getNumResults(); i != e; ++i)
- operandMap[opStmt->getResult(i)] = newOp->getResult(i);
+ for (unsigned i = 0, e = opInst->getNumResults(); i != e; ++i)
+ operandMap[opInst->getResult(i)] = newOp->getResult(i);
return newOp;
}
@@ -784,43 +787,43 @@ Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap,
for (auto *opValue : getOperands())
operands.push_back(remapOperand(opValue));
- if (auto *forStmt = dyn_cast<ForStmt>(this)) {
- auto lbMap = forStmt->getLowerBoundMap();
- auto ubMap = forStmt->getUpperBoundMap();
+ if (auto *forInst = dyn_cast<ForInst>(this)) {
+ auto lbMap = forInst->getLowerBoundMap();
+ auto ubMap = forInst->getUpperBoundMap();
- auto *newFor = ForStmt::create(
+ auto *newFor = ForInst::create(
getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()),
- ubMap, forStmt->getStep());
+ ubMap, forInst->getStep());
// Remember the induction variable mapping.
- operandMap[forStmt] = newFor;
+ operandMap[forInst] = newFor;
// Recursively clone the body of the for loop.
- for (auto &subStmt : *forStmt->getBody())
- newFor->getBody()->push_back(subStmt.clone(operandMap, context));
+ for (auto &subInst : *forInst->getBody())
+ newFor->getBody()->push_back(subInst.clone(operandMap, context));
return newFor;
}
- // Otherwise, we must have an If statement.
- auto *ifStmt = cast<IfStmt>(this);
- auto *newIf = IfStmt::create(getLoc(), operands, ifStmt->getIntegerSet());
+ // Otherwise, we must have an If instruction.
+ auto *ifInst = cast<IfInst>(this);
+ auto *newIf = IfInst::create(getLoc(), operands, ifInst->getIntegerSet());
auto *resultThen = newIf->getThen();
- for (auto &childStmt : *ifStmt->getThen())
- resultThen->push_back(childStmt.clone(operandMap, context));
+ for (auto &childInst : *ifInst->getThen())
+ resultThen->push_back(childInst.clone(operandMap, context));
- if (ifStmt->hasElse()) {
+ if (ifInst->hasElse()) {
auto *resultElse = newIf->createElse();
- for (auto &childStmt : *ifStmt->getElse())
- resultElse->push_back(childStmt.clone(operandMap, context));
+ for (auto &childInst : *ifInst->getElse())
+ resultElse->push_back(childInst.clone(operandMap, context));
}
return newIf;
}
-Statement *Statement::clone(MLIRContext *context) const {
+Instruction *Instruction::clone(MLIRContext *context) const {
DenseMap<const Value *, Value *> operandMap;
return clone(operandMap, context);
}
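For orientation, the remapping contract above can be exercised as in the
following sketch. It assumes the in-tree MLIR headers at this revision and is
not buildable standalone; cloneWithRemap, oldVal, and newVal are illustrative
names, not part of this commit.

#include "llvm/ADT/DenseMap.h"
#include "mlir/IR/Instructions.h"
using namespace mlir;

// Clone 'original' while substituting a single outside value: operands that
// use 'oldVal' are rewritten to 'newVal', all others are left alone, and the
// map is extended with the results of every cloned sub-instruction.
static Instruction *cloneWithRemap(Instruction *original, Value *oldVal,
                                   Value *newVal, MLIRContext *context) {
  llvm::DenseMap<const Value *, Value *> operandMap;
  operandMap[oldVal] = newVal;
  return original->clone(operandMap, context);
}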
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index ccd7d65f7c8..9cd4355e4aa 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -17,10 +17,10 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/Statements.h"
using namespace mlir;
/// Form the OperationName for an op with the specified string. This either is
@@ -279,7 +279,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) {
if (op->getFunction()->isML()) {
Block *block = op->getBlock();
if (!block || block->getContainingInst() || &block->back() != op)
- return op->emitOpError("must be the last statement in the ML function");
+ return op->emitOpError("must be the last instruction in the ML function");
} else {
const Block *block = op->getBlock();
if (!block || &block->back() != op)
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 8c41d488a8b..90d768c844e 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -16,7 +16,7 @@
// =============================================================================
#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/IR/Value.h"
using namespace mlir;
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index c7a5e42dd99..a213f05a932 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -17,7 +17,7 @@
#include "mlir/IR/Value.h"
#include "mlir/IR/Function.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
using namespace mlir;
/// If this value is the result of an Instruction, return the instruction
@@ -35,8 +35,8 @@ Function *Value::getFunction() {
return cast<BlockArgument>(this)->getFunction();
case Value::Kind::InstResult:
return getDefiningInst()->getFunction();
- case Value::Kind::ForStmt:
- return cast<ForStmt>(this)->getFunction();
+ case Value::Kind::ForInst:
+ return cast<ForInst>(this)->getFunction();
}
}
@@ -59,10 +59,10 @@ MLIRContext *IROperandOwner::getContext() const {
switch (getKind()) {
case Kind::OperationInst:
return cast<OperationInst>(this)->getContext();
- case Kind::ForStmt:
- return cast<ForStmt>(this)->getContext();
- case Kind::IfStmt:
- return cast<IfStmt>(this)->getContext();
+ case Kind::ForInst:
+ return cast<ForInst>(this)->getContext();
+ case Kind::IfInst:
+ return cast<IfInst>(this)->getContext();
}
}
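The switch above is the usual LLVM-style kind dispatch: a stored tag, not C++
RTTI, identifies the concrete class. A self-contained toy version of the
pattern (Owner and ForOwner are stand-ins, not MLIR types):

#include <cassert>

struct Owner {
  enum class Kind { OperationInst, ForInst, IfInst };
  explicit Owner(Kind kind) : kind(kind) {}
  Kind getKind() const { return kind; }

private:
  Kind kind;
};

struct ForOwner : Owner {
  ForOwner() : Owner(Kind::ForInst) {}
  // Hook consumed by llvm::isa/cast-style helpers.
  static bool classof(const Owner *o) { return o->getKind() == Kind::ForInst; }
};

int main() {
  ForOwner f;
  const Owner *o = &f;
  assert(ForOwner::classof(o)); // the tag alone resolves the kind
}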
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 6cc1aba72b3..3f05a4a145a 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -26,12 +26,12 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/STLExtras.h"
#include "mlir/Transforms/Utils.h"
@@ -2071,7 +2071,7 @@ FunctionParser::~FunctionParser() {
}
}
-/// Parse a SSA operand for an instruction or statement.
+/// Parse an SSA operand for an instruction.
///
/// ssa-use ::= ssa-id
///
@@ -2716,7 +2716,7 @@ ParseResult CFGFunctionParser::parseFunctionBody() {
/// Basic block declaration.
///
-/// basic-block ::= bb-label instruction* terminator-stmt
+/// basic-block ::= bb-label instruction* terminator-inst
/// bb-label ::= bb-id bb-arg-list? `:`
/// bb-id ::= bare-id
/// bb-arg-list ::= `(` ssa-id-and-type-list? `)`
@@ -2786,16 +2786,16 @@ private:
/// more specific builder type.
FuncBuilder builder;
- ParseResult parseForStmt();
+ ParseResult parseForInst();
ParseResult parseIntConstant(int64_t &val);
ParseResult parseDimAndSymbolList(SmallVectorImpl<Value *> &operands,
unsigned numDims, unsigned numOperands,
const char *affineStructName);
ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map,
bool isLower);
- ParseResult parseIfStmt();
+ ParseResult parseIfInst();
ParseResult parseElseClause(Block *elseClause);
- ParseResult parseStatements(Block *block);
+ ParseResult parseInstructions(Block *block);
ParseResult parseBlock(Block *block);
bool parseSuccessorAndUseList(Block *&dest,
@@ -2809,19 +2809,19 @@ private:
ParseResult MLFunctionParser::parseFunctionBody() {
auto braceLoc = getToken().getLoc();
- // Parse statements in this function.
+ // Parse instructions in this function.
if (parseBlock(function->getBody()))
return ParseFailure;
return finalizeFunction(function, braceLoc);
}
-/// For statement.
+/// For instruction.
///
-/// ml-for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound
-/// (`step` integer-literal)? `{` ml-stmt* `}`
+/// ml-for-inst ::= `for` ssa-id `=` lower-bound `to` upper-bound
+/// (`step` integer-literal)? `{` ml-inst* `}`
///
-ParseResult MLFunctionParser::parseForStmt() {
+ParseResult MLFunctionParser::parseForInst() {
consumeToken(Token::kw_for);
// Parse induction variable.
@@ -2862,23 +2862,23 @@ ParseResult MLFunctionParser::parseForStmt() {
return emitError("step has to be a positive integer");
}
- // Create for statement.
- ForStmt *forStmt =
+ // Create for instruction.
+ ForInst *forInst =
builder.createFor(getEncodedSourceLocation(loc), lbOperands, lbMap,
ubOperands, ubMap, step);
// Create SSA value definition for the induction variable.
- if (addDefinition({inductionVariableName, 0, loc}, forStmt))
+ if (addDefinition({inductionVariableName, 0, loc}, forInst))
return ParseFailure;
- // If parsing of the for statement body fails,
- // MLIR contains for statement with those nested statements that have been
+ // If parsing of the for instruction body fails, the IR retains the for
+ // instruction with those nested instructions that have been
// successfully parsed.
- if (parseBlock(forStmt->getBody()))
+ if (parseBlock(forInst->getBody()))
return ParseFailure;
// Reset insertion point to the current block.
- builder.setInsertionPointToEnd(forStmt->getBlock());
+ builder.setInsertionPointToEnd(forInst->getBlock());
return ParseSuccess;
}
@@ -3007,7 +3007,7 @@ ParseResult MLFunctionParser::parseBound(SmallVectorImpl<Value *> &operands,
// Create an identity map using dim id for an induction variable and
// symbol otherwise. This representation is optimized for storage.
// Analysis passes may expand it into a multi-dimensional map if desired.
- if (isa<ForStmt>(operands[0]))
+ if (isa<ForInst>(operands[0]))
map = builder.getDimIdentityMap();
else
map = builder.getSymbolIdentityMap();
@@ -3095,14 +3095,14 @@ IntegerSet Parser::parseIntegerSetInline() {
return set;
}
-/// If statement.
+/// If instruction.
///
-/// ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}`
-/// | ml-if-head `else` `if` ml-if-cond `{` ml-stmt* `}`
-/// ml-if-stmt ::= ml-if-head
-/// | ml-if-head `else` `{` ml-stmt* `}`
+/// ml-if-head ::= `if` ml-if-cond `{` ml-inst* `}`
+/// | ml-if-head `else` `if` ml-if-cond `{` ml-inst* `}`
+/// ml-if-inst ::= ml-if-head
+/// | ml-if-head `else` `{` ml-inst* `}`
///
-ParseResult MLFunctionParser::parseIfStmt() {
+ParseResult MLFunctionParser::parseIfInst() {
auto loc = getToken().getLoc();
consumeToken(Token::kw_if);
@@ -3115,25 +3115,25 @@ ParseResult MLFunctionParser::parseIfStmt() {
"integer set"))
return ParseFailure;
- IfStmt *ifStmt =
+ IfInst *ifInst =
builder.createIf(getEncodedSourceLocation(loc), operands, set);
- Block *thenClause = ifStmt->getThen();
+ Block *thenClause = ifInst->getThen();
- // When parsing of an if statement body fails, the IR contains
- // the if statement with the portion of the body that has been
+ // When parsing of an if instruction body fails, the IR contains
+ // the if instruction with the portion of the body that has been
// successfully parsed.
if (parseBlock(thenClause))
return ParseFailure;
if (consumeIf(Token::kw_else)) {
- auto *elseClause = ifStmt->createElse();
+ auto *elseClause = ifInst->createElse();
if (parseElseClause(elseClause))
return ParseFailure;
}
// Reset insertion point to the current block.
- builder.setInsertionPointToEnd(ifStmt->getBlock());
+ builder.setInsertionPointToEnd(ifInst->getBlock());
return ParseSuccess;
}
@@ -3141,25 +3141,25 @@ ParseResult MLFunctionParser::parseIfStmt() {
ParseResult MLFunctionParser::parseElseClause(Block *elseClause) {
if (getToken().is(Token::kw_if)) {
builder.setInsertionPointToEnd(elseClause);
- return parseIfStmt();
+ return parseIfInst();
}
return parseBlock(elseClause);
}
///
-/// Parse a list of statements ending with `return` or `}`
+/// Parse a list of instructions ending with `return` or `}`
///
-ParseResult MLFunctionParser::parseStatements(Block *block) {
+ParseResult MLFunctionParser::parseInstructions(Block *block) {
auto createOpFunc = [&](const OperationState &state) -> OperationInst * {
return builder.createOperation(state);
};
builder.setInsertionPointToEnd(block);
- // Parse statements till we see '}' or 'return'.
- // Return statement is parsed separately to emit a more intuitive error
- // when '}' is missing after the return statement.
+ // Parse instructions till we see '}' or 'return'.
+ // Return instruction is parsed separately to emit a more intuitive error
+ // when '}' is missing after the return instruction.
while (getToken().isNot(Token::r_brace, Token::kw_return)) {
switch (getToken().getKind()) {
default:
@@ -3167,17 +3167,17 @@ ParseResult MLFunctionParser::parseStatements(Block *block) {
return ParseFailure;
break;
case Token::kw_for:
- if (parseForStmt())
+ if (parseForInst())
return ParseFailure;
break;
case Token::kw_if:
- if (parseIfStmt())
+ if (parseIfInst())
return ParseFailure;
break;
} // end switch
}
- // Parse the return statement.
+ // Parse the return instruction.
if (getToken().is(Token::kw_return))
if (parseOperation(createOpFunc))
return ParseFailure;
@@ -3186,12 +3186,12 @@ ParseResult MLFunctionParser::parseStatements(Block *block) {
}
///
-/// Parse `{` ml-stmt* `}`
+/// Parse `{` ml-inst* `}`
///
ParseResult MLFunctionParser::parseBlock(Block *block) {
- if (parseToken(Token::l_brace, "expected '{' before statement list") ||
- parseStatements(block) ||
- parseToken(Token::r_brace, "expected '}' after statement list"))
+ if (parseToken(Token::l_brace, "expected '{' before instruction list") ||
+ parseInstructions(block) ||
+ parseToken(Token::r_brace, "expected '}' after instruction list"))
return ParseFailure;
return ParseSuccess;
@@ -3429,7 +3429,7 @@ ParseResult ModuleParser::parseCFGFunc() {
/// ML function declarations.
///
/// ml-func ::= `mlfunc` ml-func-signature
-/// (`attributes` attribute-dict)? `{` ml-stmt* ml-return-stmt
+/// (`attributes` attribute-dict)? `{` ml-inst* ml-return-inst
/// `}`
///
ParseResult ModuleParser::parseMLFunc() {
diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
index 0f130e19e26..20e8e0af214 100644
--- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
@@ -21,9 +21,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
-#include "mlir/IR/Statements.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/FileUtilities.h"
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index a5b45ba4098..80e3dd955c3 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -24,7 +24,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Support/Functional.h"
#include "mlir/Transforms/Passes.h"
@@ -207,24 +207,24 @@ struct CFGCSE : public CSEImpl {
};
/// Common sub-expression elimination for ML functions.
-struct MLCSE : public CSEImpl, StmtWalker<MLCSE> {
- using StmtWalker<MLCSE>::walk;
+struct MLCSE : public CSEImpl, InstWalker<MLCSE> {
+ using InstWalker<MLCSE>::walk;
void run(Function *f) {
- // Walk the function statements.
+ // Walk the function instructions.
walk(f);
// Finally, erase any redundant operations.
eraseDeadOperations();
}
- // Insert a scope for each statement range.
+ // Insert a scope for each instruction range.
template <class Iterator> void walk(Iterator Start, Iterator End) {
ScopedMapTy::ScopeTy scope(knownValues);
- StmtWalker<MLCSE>::walk(Start, End);
+ InstWalker<MLCSE>::walk(Start, End);
}
- void visitOperationInst(OperationInst *stmt) { simplifyOperation(stmt); }
+ void visitOperationInst(OperationInst *inst) { simplifyOperation(inst); }
};
} // end anonymous namespace
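The scope-per-range trick in MLCSE above comes from llvm::ScopedHashTable: a
scope is opened when the walk enters a nested instruction range and popped when
it leaves, so CSE candidates cannot leak across regions. A minimal
self-contained illustration, with int keys standing in for the operation hash:

#include "llvm/ADT/ScopedHashTable.h"
#include <cassert>

int main() {
  llvm::ScopedHashTable<int, const char *> knownValues;
  using ScopeTy = llvm::ScopedHashTableScope<int, const char *>;

  ScopeTy functionScope(knownValues);
  knownValues.insert(1, "defined at function level");
  {
    ScopeTy bodyScope(knownValues); // e.g. entering a ForInst body
    knownValues.insert(2, "defined inside the loop");
    assert(knownValues.count(1) && knownValues.count(2));
  } // leaving the range pops every insertion made in that scope
  assert(knownValues.count(1) && !knownValues.count(2));
}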
diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp
index c97b83f8485..f5edf2d8b81 100644
--- a/mlir/lib/Transforms/ComposeAffineMaps.cpp
+++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp
@@ -25,7 +25,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Passes.h"
@@ -36,20 +36,20 @@ using namespace mlir;
namespace {
-// ComposeAffineMaps walks stmt blocks in a Function, and for each
+// ComposeAffineMaps walks inst blocks in a Function, and for each
// AffineApplyOp, forward substitutes its results into any users which are
// also AffineApplyOps. After forward substituting its results, AffineApplyOps
// with no remaining uses are collected and erased after the walk.
// TODO(andydavis) Remove this when Chris adds instruction combiner pass.
-struct ComposeAffineMaps : public FunctionPass, StmtWalker<ComposeAffineMaps> {
+struct ComposeAffineMaps : public FunctionPass, InstWalker<ComposeAffineMaps> {
std::vector<OperationInst *> affineApplyOpsToErase;
explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {}
- using InstListType = llvm::iplist<Statement>;
+ using InstListType = llvm::iplist<Instruction>;
void walk(InstListType::iterator Start, InstListType::iterator End);
- void visitOperationInst(OperationInst *stmt);
+ void visitOperationInst(OperationInst *inst);
PassResult runOnMLFunction(Function *f) override;
- using StmtWalker<ComposeAffineMaps>::walk;
+ using InstWalker<ComposeAffineMaps>::walk;
static char passID;
};
@@ -66,14 +66,14 @@ void ComposeAffineMaps::walk(InstListType::iterator Start,
InstListType::iterator End) {
while (Start != End) {
walk(&(*Start));
- // Increment iterator after walk as visit function can mutate stmt list
+ // Increment iterator after walk as visit function can mutate inst list
// ahead of 'Start'.
++Start;
}
}
-void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) {
- if (auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>()) {
+void ComposeAffineMaps::visitOperationInst(OperationInst *opInst) {
+ if (auto affineApplyOp = opInst->dyn_cast<AffineApplyOp>()) {
forwardSubstitute(affineApplyOp);
bool allUsesEmpty = true;
for (auto *result : affineApplyOp->getInstruction()->getResults()) {
@@ -83,7 +83,7 @@ void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) {
}
}
if (allUsesEmpty) {
- affineApplyOpsToErase.push_back(opStmt);
+ affineApplyOpsToErase.push_back(opInst);
}
}
}
@@ -91,8 +91,8 @@ void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) {
PassResult ComposeAffineMaps::runOnMLFunction(Function *f) {
affineApplyOpsToErase.clear();
walk(f);
- for (auto *opStmt : affineApplyOpsToErase) {
- opStmt->erase();
+ for (auto *opInst : affineApplyOpsToErase) {
+ opInst->erase();
}
return success();
}
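ComposeAffineMaps defers the actual erasure because deleting an instruction
mid-walk would invalidate the walker's position. The same collect-then-erase
idiom in self-contained form, with a std::list standing in for the
instruction list:

#include <cassert>
#include <list>
#include <vector>

int main() {
  std::list<int> instList = {1, 2, 3, 4};
  std::vector<std::list<int>::iterator> toErase;

  // Walk phase: only record what is dead ("even" stands in for "no uses").
  for (auto it = instList.begin(); it != instList.end(); ++it)
    if (*it % 2 == 0)
      toErase.push_back(it);

  // Erase phase: safe now that the walk has finished; erasing one list
  // element does not invalidate the recorded iterators to the others.
  for (auto it : toErase)
    instList.erase(it);

  assert(instList == (std::list<int>{1, 3}));
}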
diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp
index 08087777e72..f482e90d7ac 100644
--- a/mlir/lib/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Transforms/ConstantFold.cpp
@@ -17,7 +17,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
@@ -26,20 +26,20 @@ using namespace mlir;
namespace {
/// Simple constant folding pass.
-struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
+struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
ConstantFold() : FunctionPass(&ConstantFold::passID) {}
// All constants in the function post folding.
SmallVector<Value *, 8> existingConstants;
// Operations that were folded and that need to be erased.
- std::vector<OperationInst *> opStmtsToErase;
+ std::vector<OperationInst *> opInstsToErase;
using ConstantFactoryType = std::function<Value *(Attribute, Type)>;
bool foldOperation(OperationInst *op,
SmallVectorImpl<Value *> &existingConstants,
ConstantFactoryType constantFactory);
- void visitOperationInst(OperationInst *stmt);
- void visitForStmt(ForStmt *stmt);
+ void visitOperationInst(OperationInst *inst);
+ void visitForInst(ForInst *inst);
PassResult runOnCFGFunction(Function *f) override;
PassResult runOnMLFunction(Function *f) override;
@@ -140,24 +140,24 @@ PassResult ConstantFold::runOnCFGFunction(Function *f) {
}
// Override the walker's operation visitor for constant folding.
-void ConstantFold::visitOperationInst(OperationInst *stmt) {
+void ConstantFold::visitOperationInst(OperationInst *inst) {
auto constantFactory = [&](Attribute value, Type type) -> Value * {
- FuncBuilder builder(stmt);
- return builder.create<ConstantOp>(stmt->getLoc(), value, type);
+ FuncBuilder builder(inst);
+ return builder.create<ConstantOp>(inst->getLoc(), value, type);
};
- if (!ConstantFold::foldOperation(stmt, existingConstants, constantFactory)) {
- opStmtsToErase.push_back(stmt);
+ if (!ConstantFold::foldOperation(inst, existingConstants, constantFactory)) {
+ opInstsToErase.push_back(inst);
}
}
-// Override the walker's 'for' statement visit for constant folding.
-void ConstantFold::visitForStmt(ForStmt *forStmt) {
- constantFoldBounds(forStmt);
+// Override the walker's 'for' instruction visit for constant folding.
+void ConstantFold::visitForInst(ForInst *forInst) {
+ constantFoldBounds(forInst);
}
PassResult ConstantFold::runOnMLFunction(Function *f) {
existingConstants.clear();
- opStmtsToErase.clear();
+ opInstsToErase.clear();
walk(f);
// At this point, these operations are dead, remove them.
@@ -165,8 +165,8 @@ PassResult ConstantFold::runOnMLFunction(Function *f) {
// side effects. When we have side effect modeling, we should verify that
// the operation is effect-free before we remove it. Until then this is
// close enough.
- for (auto *stmt : opStmtsToErase) {
- stmt->erase();
+ for (auto *inst : opInstsToErase) {
+ inst->erase();
}
// By the time we are done, we may have simplified a bunch of code, leaving
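The constantFactory parameter threaded through foldOperation above keeps the
folding core agnostic about where constants are built; each caller passes a
lambda that materializes the ConstantOp with its own builder. A toy version of
that callback shape, with plain ints standing in for Attribute/Value (the
return convention mirrors how the pass tests the result):

#include <cassert>
#include <functional>
#include <vector>

using Value = int;
using ConstantFactory = std::function<Value(int /*folded attribute*/)>;

// Returns true when the op cannot be folded; false means a constant was made.
bool foldOperation(int op, std::vector<Value> &existingConstants,
                   ConstantFactory constantFactory) {
  if (op < 0)
    return true; // not foldable
  Value c = constantFactory(op * 2); // caller decides how to materialize it
  existingConstants.push_back(c);
  return false;
}

int main() {
  std::vector<Value> constants;
  auto factory = [](int attr) -> Value { return attr; }; // toy "builder"
  assert(!foldOperation(3, constants, factory));
  assert(constants.size() == 1 && constants[0] == 6);
}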
diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp
index 821f35ca539..abce624b06f 100644
--- a/mlir/lib/Transforms/ConvertToCFG.cpp
+++ b/mlir/lib/Transforms/ConvertToCFG.cpp
@@ -21,9 +21,9 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
-#include "mlir/IR/StmtVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/Functional.h"
@@ -39,14 +39,14 @@ using namespace mlir;
namespace {
// Generates CFG function equivalent to the given ML function.
-class FunctionConverter : public StmtVisitor<FunctionConverter> {
+class FunctionConverter : public InstVisitor<FunctionConverter> {
public:
FunctionConverter(Function *cfgFunc) : cfgFunc(cfgFunc), builder(cfgFunc) {}
Function *convert(Function *mlFunc);
- void visitForStmt(ForStmt *forStmt);
- void visitIfStmt(IfStmt *ifStmt);
- void visitOperationInst(OperationInst *opStmt);
+ void visitForInst(ForInst *forInst);
+ void visitIfInst(IfInst *ifInst);
+ void visitOperationInst(OperationInst *opInst);
private:
Value *getConstantIndexValue(int64_t value);
@@ -64,49 +64,49 @@ private:
} // end anonymous namespace
// Return a vector of OperationInst's arguments as Values. For each
-// statement operands, represented as Value, lookup its Value conterpart in
+// instruction operand, represented as a Value, look up its Value counterpart in
// the valueRemapping table.
static llvm::SmallVector<mlir::Value *, 4>
-operandsAs(Statement *opStmt,
+operandsAs(Instruction *opInst,
const llvm::DenseMap<const Value *, Value *> &valueRemapping) {
llvm::SmallVector<Value *, 4> operands;
- for (const Value *operand : opStmt->getOperands()) {
+ for (const Value *operand : opInst->getOperands()) {
assert(valueRemapping.count(operand) != 0 && "operand is not defined");
operands.push_back(valueRemapping.lookup(operand));
}
return operands;
}
-// Convert an operation statement into an operation instruction.
+// Convert an ML function operation instruction into its CFG counterpart.
//
// The operation description (name, number and types of operands or results)
// remains the same, but the operand values must be remapped through the
// valueRemapping table as the conversion is performed. The new operation
// instruction is appended to the current block (end of SESE region).
-void FunctionConverter::visitOperationInst(OperationInst *opStmt) {
+void FunctionConverter::visitOperationInst(OperationInst *opInst) {
// Set up basic operation state (context, name, operands).
- OperationState state(cfgFunc->getContext(), opStmt->getLoc(),
- opStmt->getName());
- state.addOperands(operandsAs(opStmt, valueRemapping));
+ OperationState state(cfgFunc->getContext(), opInst->getLoc(),
+ opInst->getName());
+ state.addOperands(operandsAs(opInst, valueRemapping));
// Set up operation return types. The corresponding Values will become
// available after the operation is created.
state.addTypes(functional::map(
- [](Value *result) { return result->getType(); }, opStmt->getResults()));
+ [](Value *result) { return result->getType(); }, opInst->getResults()));
// Copy attributes.
- for (auto attr : opStmt->getAttrs()) {
+ for (auto attr : opInst->getAttrs()) {
state.addAttribute(attr.first.strref(), attr.second);
}
- auto opInst = builder.createOperation(state);
+ auto op = builder.createOperation(state);
// Make results of the operation accessible to the following operations
// through remapping.
- assert(opInst->getNumResults() == opStmt->getNumResults());
+ assert(opInst->getNumResults() == op->getNumResults());
for (unsigned i = 0, n = opInst->getNumResults(); i < n; ++i) {
valueRemapping.insert(
- std::make_pair(opStmt->getResult(i), opInst->getResult(i)));
+ std::make_pair(opInst->getResult(i), op->getResult(i)));
}
}
@@ -116,10 +116,10 @@ Value *FunctionConverter::getConstantIndexValue(int64_t value) {
return op->getResult();
}
-// Visit all statements in the given statement block.
+// Visit all instructions in the given instruction block.
void FunctionConverter::visitBlock(Block *Block) {
- for (auto &stmt : *Block)
- this->visit(&stmt);
+ for (auto &inst : *Block)
+ this->visit(&inst);
}
// Given a range of values, emit the code that reduces them with "min" or "max"
@@ -211,7 +211,7 @@ Value *FunctionConverter::buildMinMaxReductionSeq(
// | <new insertion point> |
// +--------------------------------+
//
-void FunctionConverter::visitForStmt(ForStmt *forStmt) {
+void FunctionConverter::visitForInst(ForInst *forInst) {
// First, store the loop insertion location so that we can go back to it after
// creating the new blocks (block creation updates the insertion point).
Block *loopInsertionPoint = builder.getInsertionBlock();
@@ -228,27 +228,27 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
  // The loop condition block has an argument for the loop induction variable.
// Create it upfront and make the loop induction variable -> basic block
- // argument remapping available to the following instructions. ForStatement
+ // argument remapping available to the following instructions. A ForInst
// is-a Value corresponding to the loop induction variable.
builder.setInsertionPointToEnd(loopConditionBlock);
Value *iv = loopConditionBlock->addArgument(builder.getIndexType());
- valueRemapping.insert(std::make_pair(forStmt, iv));
+ valueRemapping.insert(std::make_pair(forInst, iv));
// Recursively construct loop body region.
// Walking manually because we need custom logic before and after traversing
// the list of children.
builder.setInsertionPointToEnd(loopBodyFirstBlock);
- visitBlock(forStmt->getBody());
+ visitBlock(forInst->getBody());
// Builder point is currently at the last block of the loop body. Append the
// induction variable stepping to this block and branch back to the exit
// condition block. Construct an affine map f : (x -> x+step) and apply this
// map to the induction variable.
- auto affStep = builder.getAffineConstantExpr(forStmt->getStep());
+ auto affStep = builder.getAffineConstantExpr(forInst->getStep());
auto affDim = builder.getAffineDimExpr(0);
auto affStepMap = builder.getAffineMap(1, 0, {affDim + affStep}, {});
auto stepOp =
- builder.create<AffineApplyOp>(forStmt->getLoc(), affStepMap, iv);
+ builder.create<AffineApplyOp>(forInst->getLoc(), affStepMap, iv);
Value *nextIvValue = stepOp->getResult(0);
builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
nextIvValue);
@@ -262,22 +262,22 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
return valueRemapping.lookup(value);
};
auto operands =
- functional::map(remapOperands, forStmt->getLowerBoundOperands());
+ functional::map(remapOperands, forInst->getLowerBoundOperands());
auto lbAffineApply = builder.create<AffineApplyOp>(
- forStmt->getLoc(), forStmt->getLowerBoundMap(), operands);
+ forInst->getLoc(), forInst->getLowerBoundMap(), operands);
Value *lowerBound = buildMinMaxReductionSeq(
- forStmt->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults());
- operands = functional::map(remapOperands, forStmt->getUpperBoundOperands());
+ forInst->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults());
+ operands = functional::map(remapOperands, forInst->getUpperBoundOperands());
auto ubAffineApply = builder.create<AffineApplyOp>(
- forStmt->getLoc(), forStmt->getUpperBoundMap(), operands);
+ forInst->getLoc(), forInst->getUpperBoundMap(), operands);
Value *upperBound = buildMinMaxReductionSeq(
- forStmt->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults());
+ forInst->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults());
builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
lowerBound);
builder.setInsertionPointToEnd(loopConditionBlock);
auto comparisonOp = builder.create<CmpIOp>(
- forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound);
+ forInst->getLoc(), CmpIPredicate::SLT, iv, upperBound);
auto comparisonResult = comparisonOp->getResult();
builder.create<CondBranchOp>(builder.getUnknownLoc(), comparisonResult,
loopBodyFirstBlock, ArrayRef<Value *>(),
@@ -288,16 +288,16 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
builder.setInsertionPointToEnd(postLoopBlock);
}
-// Convert an "if" statement into a flow of basic blocks.
+// Convert an "if" instruction into a flow of basic blocks.
//
-// Create an SESE region for the if statement (including its "then" and optional
-// "else" statement blocks) and append it to the end of the current region. The
-// conditional region consists of a sequence of condition-checking blocks that
-// implement the short-circuit scheme, followed by a "then" SESE region and an
-// "else" SESE region, and the continuation block that post-dominates all blocks
-// of the "if" statement. The flow of blocks that correspond to the "then" and
-// "else" clauses are constructed recursively, enabling easy nesting of "if"
-// statements and if-then-else-if chains.
+// Create an SESE region for the if instruction (including its "then" and
+// optional "else" instruction blocks) and append it to the end of the current
+// region. The conditional region consists of a sequence of condition-checking
+// blocks that implement the short-circuit scheme, followed by a "then" SESE
+// region and an "else" SESE region, and the continuation block that
+// post-dominates all blocks of the "if" instruction. The flow of blocks that
+// correspond to the "then" and "else" clauses are constructed recursively,
+// enabling easy nesting of "if" instructions and if-then-else-if chains.
//
// +--------------------------------+
// | <end of current SESE region> |
@@ -365,17 +365,17 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
// | <new insertion point> |
// +--------------------------------+
//
-void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
- assert(ifStmt != nullptr);
+void FunctionConverter::visitIfInst(IfInst *ifInst) {
+ assert(ifInst != nullptr);
- auto integerSet = ifStmt->getCondition().getIntegerSet();
+ auto integerSet = ifInst->getCondition().getIntegerSet();
// Create basic blocks for the 'then' block and for the 'else' block.
// Although 'else' block may be empty in absence of an 'else' clause, create
// it anyway for the sake of consistency and output IR readability. Also
// create extra blocks for condition checking to prepare for short-circuit
- // logic: conditions in the 'if' statement are conjunctive, so we can jump to
- // the false branch as soon as one condition fails. `cond_br` requires
+ // logic: conditions in the 'if' instruction are conjunctive, so we can jump
+ // to the false branch as soon as one condition fails. `cond_br` requires
// another block as a target when the condition is true, and that block will
// contain the next condition.
Block *ifInsertionBlock = builder.getInsertionBlock();
@@ -412,14 +412,14 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
builder.getAffineMap(integerSet.getNumDims(),
integerSet.getNumSymbols(), constraintExpr, {});
auto affineApplyOp = builder.create<AffineApplyOp>(
- ifStmt->getLoc(), affineMap, operandsAs(ifStmt, valueRemapping));
+ ifInst->getLoc(), affineMap, operandsAs(ifInst, valueRemapping));
Value *affResult = affineApplyOp->getResult(0);
// Compare the result of the apply and branch.
auto comparisonOp = builder.create<CmpIOp>(
- ifStmt->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE,
+ ifInst->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE,
affResult, zeroConstant);
- builder.create<CondBranchOp>(ifStmt->getLoc(), comparisonOp->getResult(),
+ builder.create<CondBranchOp>(ifInst->getLoc(), comparisonOp->getResult(),
nextBlock, /*trueArgs*/ ArrayRef<Value *>(),
elseBlock,
/*falseArgs*/ ArrayRef<Value *>());
@@ -429,13 +429,13 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
// Recursively traverse the 'then' block.
builder.setInsertionPointToEnd(thenBlock);
- visitBlock(ifStmt->getThen());
+ visitBlock(ifInst->getThen());
Block *lastThenBlock = builder.getInsertionBlock();
// Recursively traverse the 'else' block if present.
builder.setInsertionPointToEnd(elseBlock);
- if (ifStmt->hasElse())
- visitBlock(ifStmt->getElse());
+ if (ifInst->hasElse())
+ visitBlock(ifInst->getElse());
Block *lastElseBlock = builder.getInsertionBlock();
// Create the continuation block here so that it appears lexically after the
@@ -443,9 +443,9 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
// to the continuation block.
Block *continuationBlock = builder.createBlock();
builder.setInsertionPointToEnd(lastThenBlock);
- builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock);
+ builder.create<BranchOp>(ifInst->getLoc(), continuationBlock);
builder.setInsertionPointToEnd(lastElseBlock);
- builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock);
+ builder.create<BranchOp>(ifInst->getLoc(), continuationBlock);
// Make sure building can continue by setting up the continuation block as the
// insertion point.
@@ -454,12 +454,12 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
// Entry point of the function converter.
//
-// Conversion is performed by recursively visiting statements of a Function.
+// Conversion is performed by recursively visiting instructions of a Function.
// It reasons in terms of single-entry single-exit (SESE) regions that are not
// materialized in the code. Instead, the pointer to the last block of the
// region is maintained throughout the conversion as the insertion point of the
// IR builder since we never change the first block after its creation. "Block"
-// statements such as loops and branches create new SESE regions for their
+// instructions such as loops and branches create new SESE regions for their
// bodies, and surround them with additional basic blocks for the control flow.
// Individual operations are simply appended to the end of the last basic block
// of the current region. The SESE invariant allows us to easily handle nested
@@ -484,9 +484,9 @@ Function *FunctionConverter::convert(Function *mlFunc) {
valueRemapping.insert(std::make_pair(mlArgument, cfgArgument));
}
- // Convert statements in order.
- for (auto &stmt : *mlFunc->getBody()) {
- visit(&stmt);
+ // Convert instructions in order.
+ for (auto &inst : *mlFunc->getBody()) {
+ visit(&inst);
}
return cfgFunc;
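In ordinary C++ terms, the CFG that visitForInst emits realizes the loop shape
sketched below: the lower bound is the maximum over the lb map's results (the
CmpIPredicate::SGT reduction), the upper bound the minimum over the ub map's
results (the SLT reduction), and the condition block is re-tested after every
step. A sketch only, assuming non-empty result lists:

#include <algorithm>
#include <cstdint>
#include <vector>

void loweredForShape(const std::vector<int64_t> &lbResults,
                     const std::vector<int64_t> &ubResults, int64_t step) {
  // buildMinMaxReductionSeq with SGT == max of the lower bound results.
  int64_t iv = *std::max_element(lbResults.begin(), lbResults.end());
  // buildMinMaxReductionSeq with SLT == min of the upper bound results.
  int64_t ub = *std::min_element(ubResults.begin(), ubResults.end());
  while (iv < ub) { // loop condition block: CmpIOp SLT, then cond_br
    // loop body blocks execute here
    iv += step; // AffineApplyOp (d0 -> d0 + step), branch back to the condition
  }
  // post-loop block: the builder's new insertion point
}

int main() { loweredForShape({0}, {8, 10}, 2); }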
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index 69344819ed8..bc7f31f0434 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -25,7 +25,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Passes.h"
@@ -49,7 +49,7 @@ namespace {
/// buffers in 'fastMemorySpace', and replaces memory operations to the former
/// by the latter. Only load op's handled for now.
/// TODO(bondhugula): extend this to store op's.
-struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> {
+struct DmaGeneration : public FunctionPass, InstWalker<DmaGeneration> {
explicit DmaGeneration(unsigned slowMemorySpace = 0,
unsigned fastMemorySpaceArg = 1,
int minDmaTransferSize = 1024)
@@ -65,10 +65,10 @@ struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> {
// Not applicable to CFG functions.
PassResult runOnCFGFunction(Function *f) override { return success(); }
PassResult runOnMLFunction(Function *f) override;
- void runOnForStmt(ForStmt *forStmt);
+ void runOnForInst(ForInst *forInst);
- void visitOperationInst(OperationInst *opStmt);
- bool generateDma(const MemRefRegion &region, ForStmt *forStmt,
+ void visitOperationInst(OperationInst *opInst);
+ bool generateDma(const MemRefRegion &region, ForInst *forInst,
uint64_t *sizeInBytes);
// List of memory regions to DMA for.
@@ -108,11 +108,11 @@ FunctionPass *mlir::createDmaGenerationPass(unsigned slowMemorySpace,
// Gather regions to promote to buffers in faster memory space.
// TODO(bondhugula): handle store op's; only load's handled for now.
-void DmaGeneration::visitOperationInst(OperationInst *opStmt) {
- if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
+void DmaGeneration::visitOperationInst(OperationInst *opInst) {
+ if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace)
return;
- } else if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
+ } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
if (storeOp->getMemRefType().getMemorySpace() != slowMemorySpace)
return;
} else {
@@ -125,7 +125,7 @@ void DmaGeneration::visitOperationInst(OperationInst *opStmt) {
// This way we would be allocating O(num of memref's) sets instead of
// O(num of load/store op's).
auto region = std::make_unique<MemRefRegion>();
- if (!getMemRefRegion(opStmt, dmaDepth, region.get())) {
+ if (!getMemRefRegion(opInst, dmaDepth, region.get())) {
LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region\n");
return;
}
@@ -170,19 +170,19 @@ static void getMultiLevelStrides(const MemRefRegion &region,
// Creates a buffer in the faster memory space for the specified region;
// generates a DMA from the lower memory space to this one, and replaces all
// loads to load from that buffer. Returns true if DMAs are generated.
-bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
+bool DmaGeneration::generateDma(const MemRefRegion &region, ForInst *forInst,
uint64_t *sizeInBytes) {
// DMAs for read regions are going to be inserted just before the for loop.
- FuncBuilder prologue(forStmt);
+ FuncBuilder prologue(forInst);
// DMAs for write regions are going to be inserted just after the for loop.
- FuncBuilder epilogue(forStmt->getBlock(),
- std::next(Block::iterator(forStmt)));
+ FuncBuilder epilogue(forInst->getBlock(),
+ std::next(Block::iterator(forInst)));
FuncBuilder *b = region.isWrite() ? &epilogue : &prologue;
// Builder to create constants at the top level.
- FuncBuilder top(forStmt->getFunction());
+ FuncBuilder top(forInst->getFunction());
- auto loc = forStmt->getLoc();
+ auto loc = forInst->getLoc();
auto *memref = region.memref;
auto memRefType = memref->getType().cast<MemRefType>();
@@ -285,7 +285,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
LLVM_DEBUG(llvm::dbgs() << "Creating a new buffer of type: ");
LLVM_DEBUG(fastMemRefType.dump(); llvm::dbgs() << "\n");
- // Create the fast memory space buffer just before the 'for' statement.
+ // Create the fast memory space buffer just before the 'for' instruction.
fastMemRef = prologue.create<AllocOp>(loc, fastMemRefType)->getResult();
// Record it.
fastBufferMap[memref] = fastMemRef;
@@ -361,58 +361,58 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
remapExprs.push_back(dimExpr - offsets[i]);
}
auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
- // *Only* those uses within the body of 'forStmt' are replaced.
+ // *Only* those uses within the body of 'forInst' are replaced.
replaceAllMemRefUsesWith(memref, fastMemRef,
/*extraIndices=*/{}, indexRemap,
/*extraOperands=*/outerIVs,
- /*domStmtFilter=*/&*forStmt->getBody()->begin());
+ /*domInstFilter=*/&*forInst->getBody()->begin());
return true;
}
-/// Returns the nesting depth of this statement, i.e., the number of loops
-/// surrounding this statement.
+/// Returns the nesting depth of this instruction, i.e., the number of loops
+/// surrounding this instruction.
// TODO(bondhugula): move this to utilities later.
-static unsigned getNestingDepth(const Statement &stmt) {
- const Statement *currStmt = &stmt;
+static unsigned getNestingDepth(const Instruction &inst) {
+ const Instruction *currInst = &inst;
unsigned depth = 0;
- while ((currStmt = currStmt->getParentStmt())) {
- if (isa<ForStmt>(currStmt))
+ while ((currInst = currInst->getParentInst())) {
+ if (isa<ForInst>(currInst))
depth++;
}
return depth;
}
-// TODO(bondhugula): make this run on a Block instead of a 'for' stmt.
-void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
+// TODO(bondhugula): make this run on a Block instead of a 'for' inst.
+void DmaGeneration::runOnForInst(ForInst *forInst) {
// For now (for testing purposes), we'll run this on the outermost among 'for'
- // stmt's with unit stride, i.e., right at the top of the tile if tiling has
+ // inst's with unit stride, i.e., right at the top of the tile if tiling has
// been done. In the future, the DMA generation has to be done at a level
// where the generated data fits in a higher level of the memory hierarchy; so
// the pass has to be instantiated with additional information that we aren't
// provided with at the moment.
- if (forStmt->getStep() != 1) {
- if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->getBody()->begin())) {
- runOnForStmt(innerFor);
+ if (forInst->getStep() != 1) {
+ if (auto *innerFor = dyn_cast<ForInst>(&*forInst->getBody()->begin())) {
+ runOnForInst(innerFor);
}
return;
}
// DMAs will be generated for this depth, i.e., for all data accessed by this
// loop.
- dmaDepth = getNestingDepth(*forStmt);
+ dmaDepth = getNestingDepth(*forInst);
regions.clear();
fastBufferMap.clear();
- // Walk this 'for' statement to gather all memory regions.
- walk(forStmt);
+ // Walk this 'for' instruction to gather all memory regions.
+ walk(forInst);
uint64_t totalSizeInBytes = 0;
bool ret = false;
for (const auto &region : regions) {
uint64_t sizeInBytes;
- bool iRet = generateDma(*region, forStmt, &sizeInBytes);
+ bool iRet = generateDma(*region, forInst, &sizeInBytes);
if (iRet)
totalSizeInBytes += sizeInBytes;
ret = ret | iRet;
@@ -426,9 +426,9 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
}
PassResult DmaGeneration::runOnMLFunction(Function *f) {
- for (auto &stmt : *f->getBody()) {
- if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
- runOnForStmt(forStmt);
+ for (auto &inst : *f->getBody()) {
+ if (auto *forInst = dyn_cast<ForInst>(&inst)) {
+ runOnForInst(forInst);
}
}
// This function never leaves the IR in an invalid state.
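getNestingDepth above is a plain parent-link walk that counts enclosing loops.
A self-contained toy equivalent, with Node and isLoop standing in for
Instruction and isa<ForInst>:

#include <cassert>

struct Node {
  Node *parent;
  bool isLoop;
};

unsigned nestingDepth(const Node &n) {
  unsigned depth = 0;
  for (const Node *cur = n.parent; cur; cur = cur->parent)
    if (cur->isLoop)
      ++depth;
  return depth;
}

int main() {
  Node outer{nullptr, true}; // outermost 'for'
  Node inner{&outer, true};  // nested 'for'
  Node load{&inner, false};  // a load inside both loops
  assert(nestingDepth(load) == 2);
}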
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index d31337437ad..97dea753f88 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -27,7 +27,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
@@ -80,20 +80,20 @@ char LoopFusion::passID = 0;
FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
-static void getSingleMemRefAccess(OperationInst *loadOrStoreOpStmt,
+static void getSingleMemRefAccess(OperationInst *loadOrStoreOpInst,
MemRefAccess *access) {
- if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
+ if (auto loadOp = loadOrStoreOpInst->dyn_cast<LoadOp>()) {
access->memref = loadOp->getMemRef();
- access->opStmt = loadOrStoreOpStmt;
+ access->opInst = loadOrStoreOpInst;
auto loadMemrefType = loadOp->getMemRefType();
access->indices.reserve(loadMemrefType.getRank());
for (auto *index : loadOp->getIndices()) {
access->indices.push_back(index);
}
} else {
- assert(loadOrStoreOpStmt->isa<StoreOp>());
- auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
- access->opStmt = loadOrStoreOpStmt;
+ assert(loadOrStoreOpInst->isa<StoreOp>());
+ auto storeOp = loadOrStoreOpInst->dyn_cast<StoreOp>();
+ access->opInst = loadOrStoreOpInst;
access->memref = storeOp->getMemRef();
auto storeMemrefType = storeOp->getMemRefType();
access->indices.reserve(storeMemrefType.getRank());
@@ -112,24 +112,24 @@ struct FusionCandidate {
MemRefAccess dstAccess;
};
-static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpStmt,
- OperationInst *dstLoadOpStmt) {
+static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpInst,
+ OperationInst *dstLoadOpInst) {
FusionCandidate candidate;
// Get store access for src loop nest.
- getSingleMemRefAccess(srcStoreOpStmt, &candidate.srcAccess);
+ getSingleMemRefAccess(srcStoreOpInst, &candidate.srcAccess);
// Get load access for dst loop nest.
- getSingleMemRefAccess(dstLoadOpStmt, &candidate.dstAccess);
+ getSingleMemRefAccess(dstLoadOpInst, &candidate.dstAccess);
return candidate;
}
-// Returns the loop depth of the loop nest surrounding 'opStmt'.
-static unsigned getLoopDepth(OperationInst *opStmt) {
+// Returns the loop depth of the loop nest surrounding 'opInst'.
+static unsigned getLoopDepth(OperationInst *opInst) {
unsigned loopDepth = 0;
- auto *currStmt = opStmt->getParentStmt();
- ForStmt *currForStmt;
- while (currStmt && (currForStmt = dyn_cast<ForStmt>(currStmt))) {
+ auto *currInst = opInst->getParentInst();
+ ForInst *currForInst;
+ while (currInst && (currForInst = dyn_cast<ForInst>(currInst))) {
++loopDepth;
- currStmt = currStmt->getParentStmt();
+ currInst = currInst->getParentInst();
}
return loopDepth;
}
@@ -137,28 +137,28 @@ static unsigned getLoopDepth(OperationInst *opStmt) {
namespace {
// LoopNestStateCollector walks loop nests and collects load and store
-// operations, and whether or not an IfStmt was encountered in the loop nest.
-class LoopNestStateCollector : public StmtWalker<LoopNestStateCollector> {
+// operations, and whether or not an IfInst was encountered in the loop nest.
+class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
- SmallVector<ForStmt *, 4> forStmts;
- SmallVector<OperationInst *, 4> loadOpStmts;
- SmallVector<OperationInst *, 4> storeOpStmts;
- bool hasIfStmt = false;
+ SmallVector<ForInst *, 4> forInsts;
+ SmallVector<OperationInst *, 4> loadOpInsts;
+ SmallVector<OperationInst *, 4> storeOpInsts;
+ bool hasIfInst = false;
- void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
+ void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
- void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; }
+ void visitIfInst(IfInst *ifInst) { hasIfInst = true; }
- void visitOperationInst(OperationInst *opStmt) {
- if (opStmt->isa<LoadOp>())
- loadOpStmts.push_back(opStmt);
- if (opStmt->isa<StoreOp>())
- storeOpStmts.push_back(opStmt);
+ void visitOperationInst(OperationInst *opInst) {
+ if (opInst->isa<LoadOp>())
+ loadOpInsts.push_back(opInst);
+ if (opInst->isa<StoreOp>())
+ storeOpInsts.push_back(opInst);
}
};
// MemRefDependenceGraph is a graph data structure where graph nodes are
-// top-level statements in a Function which contain load/store ops, and edges
+// top-level instructions in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
struct MemRefDependenceGraph {
@@ -170,18 +170,18 @@ public:
// The unique identifier of this node in the graph.
unsigned id;
    // The top-level instruction which is (or contains) loads/stores.
- Statement *stmt;
+ Instruction *inst;
// List of load operations.
SmallVector<OperationInst *, 4> loads;
- // List of store op stmts.
+ // List of store op insts.
SmallVector<OperationInst *, 4> stores;
- Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {}
+ Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}
// Returns the load op count for 'memref'.
unsigned getLoadOpCount(Value *memref) {
unsigned loadOpCount = 0;
- for (auto *loadOpStmt : loads) {
- if (memref == loadOpStmt->cast<LoadOp>()->getMemRef())
+ for (auto *loadOpInst : loads) {
+ if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
++loadOpCount;
}
return loadOpCount;
@@ -190,8 +190,8 @@ public:
// Returns the store op count for 'memref'.
unsigned getStoreOpCount(Value *memref) {
unsigned storeOpCount = 0;
- for (auto *storeOpStmt : stores) {
- if (memref == storeOpStmt->cast<StoreOp>()->getMemRef())
+ for (auto *storeOpInst : stores) {
+ if (memref == storeOpInst->cast<StoreOp>()->getMemRef())
++storeOpCount;
}
return storeOpCount;
@@ -315,10 +315,10 @@ public:
void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads,
const SmallVectorImpl<OperationInst *> &stores) {
Node *node = getNode(id);
- for (auto *loadOpStmt : loads)
- node->loads.push_back(loadOpStmt);
- for (auto *storeOpStmt : stores)
- node->stores.push_back(storeOpStmt);
+ for (auto *loadOpInst : loads)
+ node->loads.push_back(loadOpInst);
+ for (auto *storeOpInst : stores)
+ node->stores.push_back(storeOpInst);
}
void print(raw_ostream &os) const {
@@ -341,55 +341,55 @@ public:
void dump() const { print(llvm::errs()); }
};
-// Intializes the data dependence graph by walking statements in 'f'.
+// Initializes the data dependence graph by walking instructions in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function *f) {
unsigned id = 0;
DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
- for (auto &stmt : *f->getBody()) {
- if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
- // Create graph node 'id' to represent top-level 'forStmt' and record
+ for (auto &inst : *f->getBody()) {
+ if (auto *forInst = dyn_cast<ForInst>(&inst)) {
+ // Create graph node 'id' to represent top-level 'forInst' and record
// all loads and store accesses it contains.
LoopNestStateCollector collector;
- collector.walkForStmt(forStmt);
- // Return false if IfStmts are found (not currently supported).
- if (collector.hasIfStmt)
+ collector.walkForInst(forInst);
+ // Return false if IfInsts are found (not currently supported).
+ if (collector.hasIfInst)
return false;
- Node node(id++, &stmt);
- for (auto *opStmt : collector.loadOpStmts) {
- node.loads.push_back(opStmt);
- auto *memref = opStmt->cast<LoadOp>()->getMemRef();
+ Node node(id++, &inst);
+ for (auto *opInst : collector.loadOpInsts) {
+ node.loads.push_back(opInst);
+ auto *memref = opInst->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
}
- for (auto *opStmt : collector.storeOpStmts) {
- node.stores.push_back(opStmt);
- auto *memref = opStmt->cast<StoreOp>()->getMemRef();
+ for (auto *opInst : collector.storeOpInsts) {
+ node.stores.push_back(opInst);
+ auto *memref = opInst->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
}
nodes.insert({node.id, node});
}
- if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
- if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
+ if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
+ if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
// Create graph node for top-level load op.
- Node node(id++, &stmt);
- node.loads.push_back(opStmt);
- auto *memref = opStmt->cast<LoadOp>()->getMemRef();
+ Node node(id++, &inst);
+ node.loads.push_back(opInst);
+ auto *memref = opInst->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
}
- if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
+ if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
// Create graph node for top-level store op.
- Node node(id++, &stmt);
- node.stores.push_back(opStmt);
- auto *memref = opStmt->cast<StoreOp>()->getMemRef();
+ Node node(id++, &inst);
+ node.stores.push_back(opInst);
+ auto *memref = opInst->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
}
}
- // Return false if IfStmts are found (not currently supported).
- if (isa<IfStmt>(&stmt))
+ // Return false if IfInsts are found (not currently supported).
+ if (isa<IfInst>(&inst))
return false;
}
@@ -421,9 +421,9 @@ bool MemRefDependenceGraph::init(Function *f) {
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
-// *) Pop a ForStmt of the worklist. This 'dstForStmt' will be a candidate
-// destination ForStmt into which fusion will be attempted.
-// *) Add each LoadOp currently in 'dstForStmt' into list 'dstLoadOps'.
+// *) Pop a ForInst off the worklist. This 'dstForInst' will be a candidate
+// destination ForInst into which fusion will be attempted.
+// *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'.
// *) For each LoadOp in 'dstLoadOps' do:
// *) Lookup dependent loop nests at earlier positions in the Function
// which have a single store op to the same memref.
@@ -434,12 +434,12 @@ bool MemRefDependenceGraph::init(Function *f) {
// bounds to be functions of 'dstLoopNest' IVs and symbols.
// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
// just before the dst load op user.
-// *) Add the newly fused load/store operation statements to the state,
+// *) Add the newly fused load/store operation instructions to the state,
// and also add newly fused load ops to 'dstLoadOps' to be considered
// as fusion dst load ops in another iteration.
// *) Remove old src loop nest and its associated state.
//
-// Given a graph where top-level statements are vertices in the set 'V' and
+// Given a graph where top-level instructions are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
@@ -471,14 +471,14 @@ public:
// Get 'dstNode' into which to attempt fusion.
auto *dstNode = mdg->getNode(dstId);
// Skip if 'dstNode' is not a loop nest.
- if (!isa<ForStmt>(dstNode->stmt))
+ if (!isa<ForInst>(dstNode->inst))
continue;
SmallVector<OperationInst *, 4> loads = dstNode->loads;
while (!loads.empty()) {
- auto *dstLoadOpStmt = loads.pop_back_val();
- auto *memref = dstLoadOpStmt->cast<LoadOp>()->getMemRef();
- // Skip 'dstLoadOpStmt' if multiple loads to 'memref' in 'dstNode'.
+ auto *dstLoadOpInst = loads.pop_back_val();
+ auto *memref = dstLoadOpInst->cast<LoadOp>()->getMemRef();
+ // Skip 'dstLoadOpInst' if multiple loads to 'memref' in 'dstNode'.
if (dstNode->getLoadOpCount(memref) != 1)
continue;
// Skip if no input edges along which to fuse.
@@ -491,7 +491,7 @@ public:
continue;
auto *srcNode = mdg->getNode(srcEdge.id);
// Skip if 'srcNode' is not a loop nest.
- if (!isa<ForStmt>(srcNode->stmt))
+ if (!isa<ForInst>(srcNode->inst))
continue;
// Skip if 'srcNode' has more than one store to 'memref'.
if (srcNode->getStoreOpCount(memref) != 1)
@@ -508,17 +508,17 @@ public:
if (mdg->getMinOutEdgeNodeId(srcNode->id) != dstId)
continue;
// Get unique 'srcNode' store op.
- auto *srcStoreOpStmt = srcNode->stores.front();
- // Build fusion candidate out of 'srcStoreOpStmt' and 'dstLoadOpStmt'.
+ auto *srcStoreOpInst = srcNode->stores.front();
+ // Build fusion candidate out of 'srcStoreOpInst' and 'dstLoadOpInst'.
FusionCandidate candidate =
- buildFusionCandidate(srcStoreOpStmt, dstLoadOpStmt);
+ buildFusionCandidate(srcStoreOpInst, dstLoadOpInst);
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
unsigned srcLoopDepth = clSrcLoopDepth.getNumOccurrences() > 0
? clSrcLoopDepth
- : getLoopDepth(srcStoreOpStmt);
+ : getLoopDepth(srcStoreOpInst);
unsigned dstLoopDepth = clDstLoopDepth.getNumOccurrences() > 0
? clDstLoopDepth
- : getLoopDepth(dstLoadOpStmt);
+ : getLoopDepth(dstLoadOpInst);
auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
&candidate.srcAccess, &candidate.dstAccess, srcLoopDepth,
dstLoopDepth);
@@ -527,19 +527,19 @@ public:
mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id);
// Record all load/store accesses in 'sliceLoopNest' at 'dstPos'.
LoopNestStateCollector collector;
- collector.walkForStmt(sliceLoopNest);
- mdg->addToNode(dstId, collector.loadOpStmts,
- collector.storeOpStmts);
+ collector.walkForInst(sliceLoopNest);
+ mdg->addToNode(dstId, collector.loadOpInsts,
+ collector.storeOpInsts);
// Add new load ops to current Node load op list 'loads' to
// continue fusing based on new operands.
- for (auto *loadOpStmt : collector.loadOpStmts)
- loads.push_back(loadOpStmt);
+ for (auto *loadOpInst : collector.loadOpInsts)
+ loads.push_back(loadOpInst);
// Promote single iteration loops to single IV value.
- for (auto *forStmt : collector.forStmts) {
- promoteIfSingleIteration(forStmt);
+ for (auto *forInst : collector.forInsts) {
+ promoteIfSingleIteration(forInst);
}
// Remove old src loop nest.
- cast<ForStmt>(srcNode->stmt)->erase();
+ cast<ForInst>(srcNode->inst)->erase();
}
}
}
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp
index 109953f2296..8f3be8a3d45 100644
--- a/mlir/lib/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Transforms/LoopTiling.cpp
@@ -55,16 +55,16 @@ char LoopTiling::passID = 0;
/// Function.
FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); }
-// Move the loop body of ForStmt 'src' from 'src' into the specified location in
+// Move the loop body of ForInst 'src' from 'src' into the specified location in
// destination's body.
-static inline void moveLoopBody(ForStmt *src, ForStmt *dest,
+static inline void moveLoopBody(ForInst *src, ForInst *dest,
Block::iterator loc) {
dest->getBody()->getInstructions().splice(loc,
src->getBody()->getInstructions());
}
-// Move the loop body of ForStmt 'src' from 'src' to the start of dest's body.
-static inline void moveLoopBody(ForStmt *src, ForStmt *dest) {
+// Move the loop body of ForInst 'src' from 'src' to the start of dest's body.
+static inline void moveLoopBody(ForInst *src, ForInst *dest) {
moveLoopBody(src, dest, dest->getBody()->begin());
}
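
moveLoopBody relies on intrusive-list splice, which relinks the body's nodes in constant time instead of copying them. std::list::splice has the same semantics; a minimal standalone analogue:

#include <cassert>
#include <list>

int main() {
  std::list<int> srcBody = {1, 2, 3}; // stand-in for src->getBody()
  std::list<int> destBody = {9};      // stand-in for dest->getBody()
  // Analogue of moveLoopBody(src, dest): move the whole source body to the
  // start of the destination body in O(1), leaving the source empty.
  destBody.splice(destBody.begin(), srcBody);
  assert(srcBody.empty() && destBody.front() == 1);
  return 0;
}
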
@@ -73,8 +73,8 @@ static inline void moveLoopBody(ForStmt *src, ForStmt *dest) {
/// depend on other dimensions. Bounds of each dimension can thus be treated
/// independently, and deriving the new bounds is much simpler and faster
/// than for the case of tiling arbitrary polyhedral shapes.
-static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
- ArrayRef<ForStmt *> newLoops,
+static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
+ ArrayRef<ForInst *> newLoops,
ArrayRef<unsigned> tileSizes) {
assert(!origLoops.empty());
assert(origLoops.size() == tileSizes.size());
@@ -138,27 +138,27 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
/// Tiles the specified band of perfectly nested loops creating tile-space loops
/// and intra-tile loops. A band is a contiguous set of loops.
// TODO(bondhugula): handle non hyper-rectangular spaces.
-UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
+UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
ArrayRef<unsigned> tileSizes) {
assert(!band.empty());
assert(band.size() == tileSizes.size());
- // Check if the supplied for stmt's are all successively nested.
+ // Check that the supplied for insts are all successively nested.
for (unsigned i = 1, e = band.size(); i < e; i++) {
- assert(band[i]->getParentStmt() == band[i - 1]);
+ assert(band[i]->getParentInst() == band[i - 1]);
}
auto origLoops = band;
- ForStmt *rootForStmt = origLoops[0];
- auto loc = rootForStmt->getLoc();
+ ForInst *rootForInst = origLoops[0];
+ auto loc = rootForInst->getLoc();
// Note that width is at least one since band isn't empty.
unsigned width = band.size();
- SmallVector<ForStmt *, 12> newLoops(2 * width);
- ForStmt *innermostPointLoop;
+ SmallVector<ForInst *, 12> newLoops(2 * width);
+ ForInst *innermostPointLoop;
// The outermost among the loops as we add more.
- auto *topLoop = rootForStmt;
+ auto *topLoop = rootForInst;
// Add intra-tile (or point) loops.
for (unsigned i = 0; i < width; i++) {
@@ -195,7 +195,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
getIndexSet(band, &cst);
if (!cst.isHyperRectangular(0, width)) {
- rootForStmt->emitError("tiled code generation unimplemented for the"
+ rootForInst->emitError("tiled code generation unimplemented for the "
"non-hyperrectangular case");
return UtilResult::Failure;
}
@@ -207,7 +207,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
}
// Erase the old loop nest.
- rootForStmt->erase();
+ rootForInst->erase();
return UtilResult::Success;
}
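
For the hyper-rectangular case handled above, the generated bounds reduce to familiar arithmetic: a tile-space loop stepping by the tile size and a point loop clamped at the original extent. A standalone model for a single unit-step loop:

#include <algorithm>
#include <cstdio>

// Tiles `for (i = 0; i < n; ++i)` by tile size t: the tile-space loop steps
// by t and the intra-tile (point) loop runs from ii to min(ii + t, n).
void tiledLoop(int n, int t) {
  for (int ii = 0; ii < n; ii += t)
    for (int i = ii; i < std::min(ii + t, n); ++i)
      std::printf("i = %d\n", i);
}
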
@@ -216,28 +216,28 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
// a temporary placeholder to test the mechanics of tiled code generation.
// Returns all maximal outermost perfect loop nests to tile.
static void getTileableBands(Function *f,
- std::vector<SmallVector<ForStmt *, 6>> *bands) {
- // Get maximal perfect nest of 'for' stmts starting from root (inclusive).
- auto getMaximalPerfectLoopNest = [&](ForStmt *root) {
- SmallVector<ForStmt *, 6> band;
- ForStmt *currStmt = root;
+ std::vector<SmallVector<ForInst *, 6>> *bands) {
+ // Get maximal perfect nest of 'for' insts starting from root (inclusive).
+ auto getMaximalPerfectLoopNest = [&](ForInst *root) {
+ SmallVector<ForInst *, 6> band;
+ ForInst *currInst = root;
do {
- band.push_back(currStmt);
- } while (currStmt->getBody()->getInstructions().size() == 1 &&
- (currStmt = dyn_cast<ForStmt>(&*currStmt->getBody()->begin())));
+ band.push_back(currInst);
+ } while (currInst->getBody()->getInstructions().size() == 1 &&
+ (currInst = dyn_cast<ForInst>(&*currInst->getBody()->begin())));
bands->push_back(band);
};
- for (auto &stmt : *f->getBody()) {
- auto *forStmt = dyn_cast<ForStmt>(&stmt);
- if (!forStmt)
+ for (auto &inst : *f->getBody()) {
+ auto *forInst = dyn_cast<ForInst>(&inst);
+ if (!forInst)
continue;
- getMaximalPerfectLoopNest(forStmt);
+ getMaximalPerfectLoopNest(forInst);
}
}
PassResult LoopTiling::runOnMLFunction(Function *f) {
- std::vector<SmallVector<ForStmt *, 6>> bands;
+ std::vector<SmallVector<ForInst *, 6>> bands;
getTileableBands(f, &bands);
// Temporary tile sizes.
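
A standalone model of the band-gathering logic above (a toy `Loop` tree with hypothetical names): a nest stays perfect exactly while each loop's body holds a single instruction that is itself a loop, matching the size() == 1 check in getMaximalPerfectLoopNest:

#include <vector>

struct Loop {
  // Toy model: body entries are the loops directly nested in this loop.
  std::vector<Loop *> body;
};

std::vector<Loop *> maximalPerfectNest(Loop *root) {
  std::vector<Loop *> band;
  Loop *cur = root;
  do {
    band.push_back(cur);
    // Descend only while the body is exactly one (loop) instruction.
  } while (cur->body.size() == 1 && (cur = cur->body.front()));
  return band;
}
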
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index 15ea0f841cc..69431bf6349 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -26,7 +26,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"
@@ -62,18 +62,18 @@ struct LoopUnroll : public FunctionPass {
const Optional<bool> unrollFull;
// Callback to obtain unroll factors; if this has a callable target, it takes
// precedence over the command-line argument or the passed argument.
- const std::function<unsigned(const ForStmt &)> getUnrollFactor;
+ const std::function<unsigned(const ForInst &)> getUnrollFactor;
explicit LoopUnroll(
Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None,
- const std::function<unsigned(const ForStmt &)> &getUnrollFactor = nullptr)
+ const std::function<unsigned(const ForInst &)> &getUnrollFactor = nullptr)
: FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor),
unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {}
PassResult runOnMLFunction(Function *f) override;
- /// Unroll this for stmt. Returns false if nothing was done.
- bool runOnForStmt(ForStmt *forStmt);
+ /// Unroll this for inst. Returns false if nothing was done.
+ bool runOnForInst(ForInst *forInst);
static const unsigned kDefaultUnrollFactor = 4;
@@ -85,13 +85,13 @@ char LoopUnroll::passID = 0;
PassResult LoopUnroll::runOnMLFunction(Function *f) {
// Gathers all innermost loops through a post order pruned walk.
- class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
+ class InnermostLoopGatherer : public InstWalker<InnermostLoopGatherer, bool> {
public:
// Store innermost loops as we walk.
- std::vector<ForStmt *> loops;
+ std::vector<ForInst *> loops;
// This method is specialized to encode custom return logic.
- using InstListType = llvm::iplist<Statement>;
+ using InstListType = llvm::iplist<Instruction>;
bool walkPostOrder(InstListType::iterator Start,
InstListType::iterator End) {
bool hasInnerLoops = false;
@@ -103,43 +103,43 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) {
return hasInnerLoops;
}
- bool walkForStmtPostOrder(ForStmt *forStmt) {
+ bool walkForInstPostOrder(ForInst *forInst) {
bool hasInnerLoops =
- walkPostOrder(forStmt->getBody()->begin(), forStmt->getBody()->end());
+ walkPostOrder(forInst->getBody()->begin(), forInst->getBody()->end());
if (!hasInnerLoops)
- loops.push_back(forStmt);
+ loops.push_back(forInst);
return true;
}
- bool walkIfStmtPostOrder(IfStmt *ifStmt) {
+ bool walkIfInstPostOrder(IfInst *ifInst) {
bool hasInnerLoops =
- walkPostOrder(ifStmt->getThen()->begin(), ifStmt->getThen()->end());
- if (ifStmt->hasElse())
+ walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end());
+ if (ifInst->hasElse())
hasInnerLoops |=
- walkPostOrder(ifStmt->getElse()->begin(), ifStmt->getElse()->end());
+ walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end());
return hasInnerLoops;
}
- bool visitOperationInst(OperationInst *opStmt) { return false; }
+ bool visitOperationInst(OperationInst *opInst) { return false; }
// FIXME: can't use base class method for this because that in turn would
// need to use the derived class method above. CRTP doesn't allow it, and
// the compiler error resulting from it is also misleading.
- using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
+ using InstWalker<InnermostLoopGatherer, bool>::walkPostOrder;
};
// Gathers all loops with trip count <= minTripCount.
- class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
+ class ShortLoopGatherer : public InstWalker<ShortLoopGatherer> {
public:
// Store short loops as we walk.
- std::vector<ForStmt *> loops;
+ std::vector<ForInst *> loops;
const unsigned minTripCount;
ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
- void visitForStmt(ForStmt *forStmt) {
- Optional<uint64_t> tripCount = getConstantTripCount(*forStmt);
+ void visitForInst(ForInst *forInst) {
+ Optional<uint64_t> tripCount = getConstantTripCount(*forInst);
if (tripCount.hasValue() && tripCount.getValue() <= minTripCount)
- loops.push_back(forStmt);
+ loops.push_back(forInst);
}
};
@@ -151,8 +151,8 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) {
// ones).
slg.walkPostOrder(f);
auto &loops = slg.loops;
- for (auto *forStmt : loops)
- loopUnrollFull(forStmt);
+ for (auto *forInst : loops)
+ loopUnrollFull(forInst);
return success();
}
@@ -167,8 +167,8 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) {
if (loops.empty())
break;
bool unrolled = false;
- for (auto *forStmt : loops)
- unrolled |= runOnForStmt(forStmt);
+ for (auto *forInst : loops)
+ unrolled |= runOnForInst(forInst);
if (!unrolled)
// Break out if nothing was unrolled.
break;
@@ -176,31 +176,31 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) {
return success();
}
-/// Unrolls a 'for' stmt. Returns true if the loop was unrolled, false
+/// Unrolls a 'for' inst. Returns true if the loop was unrolled, false
/// otherwise. The default unroll factor is 4.
-bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
+bool LoopUnroll::runOnForInst(ForInst *forInst) {
// Use the function callback if one was provided.
if (getUnrollFactor) {
- return loopUnrollByFactor(forStmt, getUnrollFactor(*forStmt));
+ return loopUnrollByFactor(forInst, getUnrollFactor(*forInst));
}
// Unroll by the factor passed, if any.
if (unrollFactor.hasValue())
- return loopUnrollByFactor(forStmt, unrollFactor.getValue());
+ return loopUnrollByFactor(forInst, unrollFactor.getValue());
// Unroll by the command line factor if one was specified.
if (clUnrollFactor.getNumOccurrences() > 0)
- return loopUnrollByFactor(forStmt, clUnrollFactor);
+ return loopUnrollByFactor(forInst, clUnrollFactor);
// Unroll completely if full loop unroll was specified.
if (clUnrollFull.getNumOccurrences() > 0 ||
(unrollFull.hasValue() && unrollFull.getValue()))
- return loopUnrollFull(forStmt);
+ return loopUnrollFull(forInst);
// Unroll by four otherwise.
- return loopUnrollByFactor(forStmt, kDefaultUnrollFactor);
+ return loopUnrollByFactor(forInst, kDefaultUnrollFactor);
}
FunctionPass *mlir::createLoopUnrollPass(
int unrollFactor, int unrollFull,
- const std::function<unsigned(const ForStmt &)> &getUnrollFactor) {
+ const std::function<unsigned(const ForInst &)> &getUnrollFactor) {
return new LoopUnroll(
unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor);
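
A hedged usage sketch of the callback path, based only on the signature in this patch (the policy inside pickFactor is hypothetical; getConstantTripCount is the helper used elsewhere in this diff, and the headers from this patch are assumed to be included). When the callback has a callable target it takes precedence over both factor arguments:

mlir::FunctionPass *makeTripCountSensitiveUnrollPass() {
  auto pickFactor = [](const mlir::ForInst &forInst) -> unsigned {
    // Hypothetical policy: unroll harder when a constant trip count is known.
    return mlir::getConstantTripCount(forInst).hasValue() ? 8 : 2;
  };
  // -1 sentinels leave unrollFactor/unrollFull unset (mapped to None above).
  return mlir::createLoopUnrollPass(/*unrollFactor=*/-1, /*unrollFull=*/-1,
                                    pickFactor);
}
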
diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
index 60e8d154f98..f59659cf234 100644
--- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp
+++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
@@ -40,7 +40,7 @@
// S6(i+1);
//
// Note: 'if/else' blocks are not jammed. So, if there are loops inside if
-// stmt's, bodies of those loops will not be jammed.
+// inst's, bodies of those loops will not be jammed.
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Passes.h"
@@ -49,7 +49,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"
@@ -75,7 +75,7 @@ struct LoopUnrollAndJam : public FunctionPass {
unrollJamFactor(unrollJamFactor) {}
PassResult runOnMLFunction(Function *f) override;
- bool runOnForStmt(ForStmt *forStmt);
+ bool runOnForInst(ForInst *forInst);
static char passID;
};
@@ -90,79 +90,79 @@ FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) {
PassResult LoopUnrollAndJam::runOnMLFunction(Function *f) {
// Currently, just the outermost loop from the first loop nest is
- // unroll-and-jammed by this pass. However, runOnForStmt can be called on any
- // for Stmt.
- auto *forStmt = dyn_cast<ForStmt>(f->getBody()->begin());
- if (!forStmt)
+ // unroll-and-jammed by this pass. However, runOnForInst can be called on any
+ // ForInst.
+ auto *forInst = dyn_cast<ForInst>(f->getBody()->begin());
+ if (!forInst)
return success();
- runOnForStmt(forStmt);
+ runOnForInst(forInst);
return success();
}
-/// Unroll and jam a 'for' stmt. Default unroll jam factor is
+/// Unroll and jam a 'for' inst. Default unroll jam factor is
/// kDefaultUnrollJamFactor. Return false if nothing was done.
-bool LoopUnrollAndJam::runOnForStmt(ForStmt *forStmt) {
+bool LoopUnrollAndJam::runOnForInst(ForInst *forInst) {
// Unroll and jam by the factor that was passed if any.
if (unrollJamFactor.hasValue())
- return loopUnrollJamByFactor(forStmt, unrollJamFactor.getValue());
+ return loopUnrollJamByFactor(forInst, unrollJamFactor.getValue());
// Otherwise, unroll jam by the command-line factor if one was specified.
if (clUnrollJamFactor.getNumOccurrences() > 0)
- return loopUnrollJamByFactor(forStmt, clUnrollJamFactor);
+ return loopUnrollJamByFactor(forInst, clUnrollJamFactor);
// Unroll and jam by four otherwise.
- return loopUnrollJamByFactor(forStmt, kDefaultUnrollJamFactor);
+ return loopUnrollJamByFactor(forInst, kDefaultUnrollJamFactor);
}
-bool mlir::loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
+bool mlir::loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor) {
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollJamFactor)
- return loopUnrollJamByFactor(forStmt, mayBeConstantTripCount.getValue());
- return loopUnrollJamByFactor(forStmt, unrollJamFactor);
+ return loopUnrollJamByFactor(forInst, mayBeConstantTripCount.getValue());
+ return loopUnrollJamByFactor(forInst, unrollJamFactor);
}
/// Unrolls and jams this loop by the specified factor.
-bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
- // Gathers all maximal sub-blocks of statements that do not themselves include
- // a for stmt (a statement could have a descendant for stmt though in its
- // tree).
- class JamBlockGatherer : public StmtWalker<JamBlockGatherer> {
+bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
+ // Gathers all maximal sub-blocks of instructions that do not themselves
+ // include a for inst (an instruction could have a descendant for inst though
+ // in its tree).
+ class JamBlockGatherer : public InstWalker<JamBlockGatherer> {
public:
- using InstListType = llvm::iplist<Statement>;
+ using InstListType = llvm::iplist<Instruction>;
- // Store iterators to the first and last stmt of each sub-block found.
+ // Store iterators to the first and last inst of each sub-block found.
std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
// This is a linear time walk.
void walk(InstListType::iterator Start, InstListType::iterator End) {
for (auto it = Start; it != End;) {
auto subBlockStart = it;
- while (it != End && !isa<ForStmt>(it))
+ while (it != End && !isa<ForInst>(it))
++it;
if (it != subBlockStart)
subBlocks.push_back({subBlockStart, std::prev(it)});
- // Process all for stmts that appear next.
- while (it != End && isa<ForStmt>(it))
- walkForStmt(cast<ForStmt>(it++));
+ // Process all for insts that appear next.
+ while (it != End && isa<ForInst>(it))
+ walkForInst(cast<ForInst>(it++));
}
}
};
assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
- if (unrollJamFactor == 1 || forStmt->getBody()->empty())
+ if (unrollJamFactor == 1 || forInst->getBody()->empty())
return false;
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
if (!mayBeConstantTripCount.hasValue() &&
- getLargestDivisorOfTripCount(*forStmt) % unrollJamFactor != 0)
+ getLargestDivisorOfTripCount(*forInst) % unrollJamFactor != 0)
return false;
- auto lbMap = forStmt->getLowerBoundMap();
- auto ubMap = forStmt->getUpperBoundMap();
+ auto lbMap = forInst->getLowerBoundMap();
+ auto ubMap = forInst->getUpperBoundMap();
// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as a Function in the general case). However, the right way to
@@ -173,7 +173,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// Same operand list for lower and upper bound for now.
// TODO(bondhugula): handle bounds with different sets of operands.
- if (!forStmt->matchingBoundOperandList())
+ if (!forInst->matchingBoundOperandList())
return false;
// If the trip count is lower than the unroll jam factor, no unroll jam.
@@ -184,7 +184,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// Gather all sub-blocks to jam upon the loop being unrolled.
JamBlockGatherer jbg;
- jbg.walkForStmt(forStmt);
+ jbg.walkForInst(forInst);
auto &subBlocks = jbg.subBlocks;
// Generate the cleanup loop if trip count isn't a multiple of
@@ -192,24 +192,24 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
DenseMap<const Value *, Value *> operandMap;
- // Insert the cleanup loop right after 'forStmt'.
- FuncBuilder builder(forStmt->getBlock(),
- std::next(Block::iterator(forStmt)));
- auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
- cleanupForStmt->setLowerBoundMap(
- getCleanupLoopLowerBound(*forStmt, unrollJamFactor, &builder));
+ // Insert the cleanup loop right after 'forInst'.
+ FuncBuilder builder(forInst->getBlock(),
+ std::next(Block::iterator(forInst)));
+ auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst, operandMap));
+ cleanupForInst->setLowerBoundMap(
+ getCleanupLoopLowerBound(*forInst, unrollJamFactor, &builder));
// The upper bound needs to be adjusted.
- forStmt->setUpperBoundMap(
- getUnrolledLoopUpperBound(*forStmt, unrollJamFactor, &builder));
+ forInst->setUpperBoundMap(
+ getUnrolledLoopUpperBound(*forInst, unrollJamFactor, &builder));
// Promote the loop body up if this has turned into a single iteration loop.
- promoteIfSingleIteration(cleanupForStmt);
+ promoteIfSingleIteration(cleanupForInst);
}
// Scale the step of loop being unroll-jammed by the unroll-jam factor.
- int64_t step = forStmt->getStep();
- forStmt->setStep(step * unrollJamFactor);
+ int64_t step = forInst->getStep();
+ forInst->setStep(step * unrollJamFactor);
for (auto &subBlock : subBlocks) {
// Builder to insert unroll-jammed bodies. Insert right at the end of
@@ -222,14 +222,14 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
- if (!forStmt->use_empty()) {
+ if (!forInst->use_empty()) {
// iv' = iv + i * step, i = 1 to unrollJamFactor-1.
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto *ivUnroll =
- builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
+ builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInst)
->getResult(0);
- operandMapping[forStmt] = ivUnroll;
+ operandMapping[forInst] = ivUnroll;
}
// Clone the sub-block being unroll-jammed.
for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {
@@ -239,7 +239,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
}
// Promote the loop body up if this has turned into a single iteration loop.
- promoteIfSingleIteration(forStmt);
+ promoteIfSingleIteration(forInst);
return true;
}
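
The arithmetic above can be modeled standalone: the step of the unrolled loop is scaled by the factor, and the i-th jammed copy of each sub-block sees the bumped induction value iv + i * step (the bumpMap in the code). For a trip count that is a multiple of the factor:

#include <cstdio>

void unrollJammedModel(int lb, int tripCount, int step, int factor) {
  // Outer step scaled by the unroll-jam factor, as in setStep() above.
  for (int iv = lb; iv < lb + tripCount * step; iv += step * factor)
    for (int i = 0; i < factor; ++i)
      std::printf("body(iv = %d)\n", iv + i * step); // iv' = iv + i * step
}
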
diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp
index 51577009abb..bcb2abf11dd 100644
--- a/mlir/lib/Transforms/LowerVectorTransfers.cpp
+++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp
@@ -110,7 +110,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
// Get the ML function builder.
// We need access to the Function builder stored internally in the
// MLFunctionLoweringRewriter; the general rewriting API does not provide
- // ML-specific functions (ForStmt and Block manipulation). While we could
+ // ML-specific functions (ForInst and Block manipulation). While we could
// forward them or define a whole rewriting chain based on MLFunctionBuilder
// instead of Builder, the code for it would be duplicated boilerplate. As we
// go towards unifying ML and CFG functions, this separation will disappear.
@@ -137,13 +137,13 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
// memory.
// TODO(ntv): Handle broadcast / slice properly.
auto permutationMap = transfer->getPermutationMap();
- SetVector<ForStmt *> loops;
+ SetVector<ForInst *> loops;
SmallVector<Value *, 8> accessIndices(transfer->getIndices());
for (auto it : llvm::enumerate(transfer->getVectorType().getShape())) {
auto composed = composeWithUnboundedMap(
getAffineDimExpr(it.index(), b.getContext()), permutationMap);
- auto *forStmt = b.createFor(transfer->getLoc(), 0, it.value());
- loops.insert(forStmt);
+ auto *forInst = b.createFor(transfer->getLoc(), 0, it.value());
+ loops.insert(forInst);
// Setting the insertion point to the innermost loop achieves nesting.
b.setInsertionPointToStart(loops.back()->getBody());
if (composed == getAffineConstantExpr(0, b.getContext())) {
@@ -196,7 +196,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
b.setInsertionPoint(transfer->getInstruction());
b.create<DeallocOp>(transfer->getLoc(), tmpScalarAlloc);
- // 7. It is now safe to erase the statement.
+ // 7. It is now safe to erase the instruction.
rewriter->replaceOp(transfer->getInstruction(), newResults);
}
@@ -213,7 +213,7 @@ public:
return matchFailure();
}
- void rewriteOpStmt(OperationInst *op,
+ void rewriteOpInst(OperationInst *op,
MLFuncGlobalLoweringState *funcWiseState,
std::unique_ptr<PatternState> opState,
MLFuncLoweringRewriter *rewriter) const override {
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index a30e8164760..37f0f571a0f 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -73,7 +73,7 @@
/// Implementation details
/// ======================
/// The current decisions made by the super-vectorization pass guarantee that
-/// use-def chains do not escape an enclosing vectorized ForStmt. In other
+/// use-def chains do not escape an enclosing vectorized ForInst. In other
/// words, this pass operates on a scoped program slice. Furthermore, since we
/// do not vectorize in the presence of conditionals for now, sliced chains are
/// guaranteed not to escape the innermost scope, which has to be either the top
@@ -247,7 +247,7 @@ static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
}
static OperationInst *
-instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType,
+instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap);
/// Not all Values belong to a program slice scoped within the immediately
@@ -263,10 +263,10 @@ static Value *substitute(Value *v, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap) {
auto it = substitutionsMap->find(v);
if (it == substitutionsMap->end()) {
- auto *opStmt = v->getDefiningInst();
- if (opStmt->isa<ConstantOp>()) {
- FuncBuilder b(opStmt);
- auto *inst = instantiate(&b, opStmt, hwVectorType, substitutionsMap);
+ auto *opInst = v->getDefiningInst();
+ if (opInst->isa<ConstantOp>()) {
+ FuncBuilder b(opInst);
+ auto *inst = instantiate(&b, opInst, hwVectorType, substitutionsMap);
auto res =
substitutionsMap->insert(std::make_pair(v, inst->getResult(0)));
assert(res.second && "Insertion failed");
@@ -285,7 +285,7 @@ static Value *substitute(Value *v, VectorType hwVectorType,
///
/// The general problem this pass solves is as follows:
/// Assume a vector_transfer operation at the super-vector granularity that has
-/// `l` enclosing loops (ForStmt). Assume the vector transfer operation operates
+/// `l` enclosing loops (ForInst). Assume the vector transfer operation operates
/// on a MemRef of rank `r`, a super-vector of rank `s` and a hardware vector of
/// rank `h`.
/// For the purpose of illustration assume l==4, r==3, s==2, h==1 and that the
@@ -347,7 +347,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType,
SmallVector<AffineExpr, 8> affineExprs;
// TODO(ntv): support a concrete map and composition.
unsigned i = 0;
- // The first numMemRefIndices correspond to ForStmt that have not been
+ // The first numMemRefIndices correspond to ForInst that have not been
// vectorized, the transformation is the identity on those.
for (i = 0; i < numMemRefIndices; ++i) {
auto d_i = b->getAffineDimExpr(i);
@@ -384,9 +384,9 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType,
/// - constant splat is replaced by constant splat of `hwVectorType`.
/// TODO(ntv): add more substitutions on a per-need basis.
static SmallVector<NamedAttribute, 1>
-materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) {
+materializeAttributes(OperationInst *opInst, VectorType hwVectorType) {
SmallVector<NamedAttribute, 1> res;
- for (auto a : opStmt->getAttrs()) {
+ for (auto a : opInst->getAttrs()) {
if (auto splat = a.second.dyn_cast<SplatElementsAttr>()) {
auto attr = SplatElementsAttr::get(hwVectorType, splat.getValue());
res.push_back(NamedAttribute(a.first, attr));
@@ -397,7 +397,7 @@ materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) {
return res;
}
-/// Creates an instantiated version of `opStmt`.
+/// Creates an instantiated version of `opInst`.
/// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no
/// affine reindexing. Just substitute their Value operands and be done. For
/// this case the actual instance is irrelevant. Just use the values in
@@ -405,11 +405,11 @@ materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) {
///
/// If the underlying substitution fails, this fails too and returns nullptr.
static OperationInst *
-instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType,
+instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap) {
- assert(!opStmt->isa<VectorTransferReadOp>() &&
+ assert(!opInst->isa<VectorTransferReadOp>() &&
"Should call the function specialized for VectorTransferReadOp");
- assert(!opStmt->isa<VectorTransferWriteOp>() &&
+ assert(!opInst->isa<VectorTransferWriteOp>() &&
"Should call the function specialized for VectorTransferWriteOp");
bool fail = false;
auto operands = map(
@@ -419,14 +419,14 @@ instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType,
fail |= !res;
return res;
},
- opStmt->getOperands());
+ opInst->getOperands());
if (fail)
return nullptr;
- auto attrs = materializeAttributes(opStmt, hwVectorType);
+ auto attrs = materializeAttributes(opInst, hwVectorType);
- OperationState state(b->getContext(), opStmt->getLoc(),
- opStmt->getName().getStringRef(), operands,
+ OperationState state(b->getContext(), opInst->getLoc(),
+ opInst->getName().getStringRef(), operands,
{hwVectorType}, attrs);
return b->createOperation(state);
}
@@ -511,11 +511,11 @@ instantiate(FuncBuilder *b, VectorTransferWriteOp *write,
return cloned->getInstruction();
}
-/// Returns `true` if stmt instance is properly cloned and inserted, false
+/// Returns `false` if the inst instance is properly cloned and inserted, true
/// otherwise.
/// The multi-dimensional `hwVectorInstance` belongs to the shapeRatio of
/// super-vector type to hw vector type.
-/// A cloned instance of `stmt` is formed as follows:
+/// A cloned instance of `inst` is formed as follows:
/// 1. vector_transfer_read: the return `superVectorType` is replaced by
/// `hwVectorType`. Additionally, affine indices are reindexed with
/// `reindexAffineIndices` using `hwVectorInstance` and vector type
@@ -532,24 +532,24 @@ instantiate(FuncBuilder *b, VectorTransferWriteOp *write,
/// possible.
///
/// Returns true on failure.
-static bool instantiateMaterialization(Statement *stmt,
+static bool instantiateMaterialization(Instruction *inst,
MaterializationState *state) {
- LLVM_DEBUG(dbgs() << "\ninstantiate: " << *stmt);
+ LLVM_DEBUG(dbgs() << "\ninstantiate: " << *inst);
- if (isa<ForStmt>(stmt))
- return stmt->emitError("NYI path ForStmt");
+ if (isa<ForInst>(inst))
+ return inst->emitError("NYI path ForInst");
- if (isa<IfStmt>(stmt))
- return stmt->emitError("NYI path IfStmt");
+ if (isa<IfInst>(inst))
+ return inst->emitError("NYI path IfInst");
// Create a builder here for unroll-and-jam effects.
- FuncBuilder b(stmt);
- auto *opStmt = cast<OperationInst>(stmt);
- if (auto write = opStmt->dyn_cast<VectorTransferWriteOp>()) {
+ FuncBuilder b(inst);
+ auto *opInst = cast<OperationInst>(inst);
+ if (auto write = opInst->dyn_cast<VectorTransferWriteOp>()) {
instantiate(&b, write, state->hwVectorType, state->hwVectorInstance,
state->substitutionsMap);
return false;
- } else if (auto read = opStmt->dyn_cast<VectorTransferReadOp>()) {
+ } else if (auto read = opInst->dyn_cast<VectorTransferReadOp>()) {
auto *clone = instantiate(&b, read, state->hwVectorType,
state->hwVectorInstance, state->substitutionsMap);
state->substitutionsMap->insert(
@@ -559,17 +559,17 @@ static bool instantiateMaterialization(Statement *stmt,
// The only op with 0 results reaching this point must, by construction, be a
// VectorTransferWriteOp and have been caught above. Ops with >= 2 results
// are not yet supported. So just support 1 result.
- if (opStmt->getNumResults() != 1)
- return stmt->emitError("NYI: ops with != 1 results");
- if (opStmt->getResult(0)->getType() != state->superVectorType)
- return stmt->emitError("Op does not return a supervector.");
+ if (opInst->getNumResults() != 1)
+ return inst->emitError("NYI: ops with != 1 results");
+ if (opInst->getResult(0)->getType() != state->superVectorType)
+ return inst->emitError("Op does not return a supervector.");
auto *clone =
- instantiate(&b, opStmt, state->hwVectorType, state->substitutionsMap);
+ instantiate(&b, opInst, state->hwVectorType, state->substitutionsMap);
if (!clone) {
return true;
}
state->substitutionsMap->insert(
- std::make_pair(opStmt->getResult(0), clone->getResult(0)));
+ std::make_pair(opInst->getResult(0), clone->getResult(0)));
return false;
}
@@ -595,7 +595,7 @@ static bool instantiateMaterialization(Statement *stmt,
/// TODO(ntv): full loops + materialized allocs.
/// TODO(ntv): partial unrolling + materialized allocs.
static bool emitSlice(MaterializationState *state,
- SetVector<Statement *> *slice) {
+ SetVector<Instruction *> *slice) {
auto ratio = shapeRatio(state->superVectorType, state->hwVectorType);
assert(ratio.hasValue() &&
"ratio of super-vector to HW-vector shape is not integral");
@@ -610,10 +610,10 @@ static bool emitSlice(MaterializationState *state,
DenseMap<const Value *, Value *> substitutionMap;
scopedState.substitutionsMap = &substitutionMap;
// The slice is topologically sorted, so we can just clone the insts in order.
- for (auto *stmt : *slice) {
- auto fail = instantiateMaterialization(stmt, &scopedState);
+ for (auto *inst : *slice) {
+ auto fail = instantiateMaterialization(inst, &scopedState);
if (fail) {
- stmt->emitError("Unhandled super-vector materialization failure");
+ inst->emitError("Unhandled super-vector materialization failure");
return true;
}
}
@@ -636,7 +636,7 @@ static bool emitSlice(MaterializationState *state,
/// Materializes super-vector types into concrete hw vector types as follows:
/// 1. start from super-vector terminators (current vector_transfer_write
/// ops);
-/// 2. collect all the statements that can be reached by transitive use-defs
+/// 2. collect all the instructions that can be reached by transitive use-defs
/// chains;
/// 3. get the superVectorType for this particular terminator and the
/// corresponding hardware vector type (for now limited to F32)
@@ -647,13 +647,13 @@ static bool emitSlice(MaterializationState *state,
/// Notes
/// =====
/// The `slice` is sorted in topological order by construction.
-/// Additionally, this set is limited to statements in the same lexical scope
+/// Additionally, this set is limited to instructions in the same lexical scope
/// because we currently disallow vectorization of defs that come from another
/// scope.
static bool materialize(Function *f,
const SetVector<OperationInst *> &terminators,
MaterializationState *state) {
- DenseSet<Statement *> seen;
+ DenseSet<Instruction *> seen;
for (auto *term : terminators) {
// Short-circuit test, a given terminator may have been reached by some
// other previous transitive use-def chains.
@@ -668,16 +668,16 @@ static bool materialize(Function *f,
// current enclosing scope of the terminator. See the top of the function
// Note for the justification of this restriction.
// TODO(ntv): relax scoping constraints.
- auto *enclosingScope = term->getParentStmt();
- auto keepIfInSameScope = [enclosingScope](Statement *stmt) {
- assert(stmt && "NULL stmt");
+ auto *enclosingScope = term->getParentInst();
+ auto keepIfInSameScope = [enclosingScope](Instruction *inst) {
+ assert(inst && "NULL inst");
if (!enclosingScope) {
// By construction, everything is always under the top scope (null scope).
return true;
}
- return properlyDominates(*enclosingScope, *stmt);
+ return properlyDominates(*enclosingScope, *inst);
};
- SetVector<Statement *> slice =
+ SetVector<Instruction *> slice =
getSlice(term, keepIfInSameScope, keepIfInSameScope);
assert(!slice.empty());
@@ -722,12 +722,12 @@ PassResult MaterializeVectorsPass::runOnMLFunction(Function *f) {
// Capture terminators; i.e. vector_transfer_write ops involving a strict
// super-vector of subVectorType.
- auto filter = [subVectorType](const Statement &stmt) {
- const auto &opStmt = cast<OperationInst>(stmt);
- if (!opStmt.isa<VectorTransferWriteOp>()) {
+ auto filter = [subVectorType](const Instruction &inst) {
+ const auto &opInst = cast<OperationInst>(inst);
+ if (!opInst.isa<VectorTransferWriteOp>()) {
return false;
}
- return matcher::operatesOnStrictSuperVectors(opStmt, subVectorType);
+ return matcher::operatesOnStrictSuperVectors(opInst, subVectorType);
};
auto pat = Op(filter);
auto matches = pat.match(f);
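
The instance enumeration underlying this pass walks all hw-vector positions inside a super-vector via the shape ratio and a row-major delinearization of the linear instance index (the file's delinearize helper, seen earlier in this diff). A standalone model; e.g. ratio {2, 3} maps linear index 4 to multi-index {1, 1}:

#include <vector>

std::vector<unsigned> delinearizeIndex(unsigned linearIndex,
                                       const std::vector<unsigned> &ratio) {
  std::vector<unsigned> multiIndex(ratio.size());
  // Peel off dimensions from innermost to outermost (row-major order).
  for (int d = static_cast<int>(ratio.size()) - 1; d >= 0; --d) {
    multiIndex[d] = linearIndex % ratio[d];
    linearIndex /= ratio[d];
  }
  return multiIndex;
}
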
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index c8a6ced4ed1..debaac3a33c 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -25,7 +25,7 @@
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
@@ -39,14 +39,14 @@ using namespace mlir;
namespace {
struct PipelineDataTransfer : public FunctionPass,
- StmtWalker<PipelineDataTransfer> {
+ InstWalker<PipelineDataTransfer> {
PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {}
PassResult runOnMLFunction(Function *f) override;
- PassResult runOnForStmt(ForStmt *forStmt);
+ PassResult runOnForInst(ForInst *forInst);
- // Collect all 'for' statements.
- void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
- std::vector<ForStmt *> forStmts;
+ // Collect all 'for' instructions.
+ void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
+ std::vector<ForInst *> forInsts;
static char passID;
};
@@ -61,26 +61,26 @@ FunctionPass *mlir::createPipelineDataTransferPass() {
return new PipelineDataTransfer();
}
-// Returns the position of the tag memref operand given a DMA statement.
+// Returns the position of the tag memref operand given a DMA instruction.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
-static unsigned getTagMemRefPos(const OperationInst &dmaStmt) {
- assert(dmaStmt.isa<DmaStartOp>() || dmaStmt.isa<DmaWaitOp>());
- if (dmaStmt.isa<DmaStartOp>()) {
+static unsigned getTagMemRefPos(const OperationInst &dmaInst) {
+ assert(dmaInst.isa<DmaStartOp>() || dmaInst.isa<DmaWaitOp>());
+ if (dmaInst.isa<DmaStartOp>()) {
// Second to last operand.
- return dmaStmt.getNumOperands() - 2;
+ return dmaInst.getNumOperands() - 2;
}
- // First operand for a dma finish statement.
+ // First operand for a dma finish instruction.
return 0;
}
-/// Doubles the buffer of the supplied memref on the specified 'for' statement
+/// Doubles the buffer of the supplied memref on the specified 'for' instruction
/// by adding a leading dimension of size two to the memref. Replaces all uses
/// of the old memref by the new one while indexing the newly added dimension by
-/// the loop IV of the specified 'for' statement modulo 2. Returns false if such
-/// a replacement cannot be performed.
-static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
- auto *forBody = forStmt->getBody();
+/// the loop IV of the specified 'for' instruction modulo 2. Returns false if
+/// such a replacement cannot be performed.
+static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) {
+ auto *forBody = forInst->getBody();
FuncBuilder bInner(forBody, forBody->begin());
bInner.setInsertionPoint(forBody, forBody->begin());
@@ -101,33 +101,33 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
auto newMemRefType = doubleShape(oldMemRefType);
// Put together alloc operands for the dynamic dimensions of the memref.
- FuncBuilder bOuter(forStmt);
+ FuncBuilder bOuter(forInst);
SmallVector<Value *, 4> allocOperands;
unsigned dynamicDimCount = 0;
for (auto dimSize : oldMemRefType.getShape()) {
if (dimSize == -1)
- allocOperands.push_back(bOuter.create<DimOp>(forStmt->getLoc(), oldMemRef,
+ allocOperands.push_back(bOuter.create<DimOp>(forInst->getLoc(), oldMemRef,
dynamicDimCount++));
}
- // Create and place the alloc right before the 'for' statement.
+ // Create and place the alloc right before the 'for' instruction.
// TODO(mlir-team): we are assuming scoped allocation here, and aren't
// inserting a dealloc -- this isn't the right thing.
Value *newMemRef =
- bOuter.create<AllocOp>(forStmt->getLoc(), newMemRefType, allocOperands);
+ bOuter.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);
// Create 'iv mod 2' value to index the leading dimension.
auto d0 = bInner.getAffineDimExpr(0);
auto modTwoMap =
bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0 % 2}, {});
auto ivModTwoOp =
- bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
+ bInner.create<AffineApplyOp>(forInst->getLoc(), modTwoMap, forInst);
- // replaceAllMemRefUsesWith will always succeed unless the forStmt body has
+ // replaceAllMemRefUsesWith will always succeed unless the forInst body has
// non-dereferencing uses of the memref.
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0),
AffineMap::Null(), {},
- &*forStmt->getBody()->begin())) {
+ &*forInst->getBody()->begin())) {
LLVM_DEBUG(llvm::dbgs()
<< "memref replacement for double buffering failed\n";);
ivModTwoOp->getInstruction()->erase();
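
The effect of doubleBuffer can be modeled standalone: the memref grows a leading dimension of size two, and iteration iv addresses slice iv % 2 (the AffineApplyOp built above), so that the subsequent pipelining can overlap the transfer into one half with compute on the other:

#include <cstdio>

void doubleBufferedModel(int n) {
  static float buf[2][128] = {}; // old memref<128xf32> -> memref<2x128xf32>
  for (int iv = 0; iv < n; ++iv) {
    float *cur = buf[iv % 2]; // analogue of the 'iv mod 2' index
    std::printf("iteration %d uses buffer half %d (%p)\n", iv, iv % 2,
                static_cast<void *>(cur));
  }
}
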
@@ -139,15 +139,15 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
/// Returns success if the IR is in a valid state.
PassResult PipelineDataTransfer::runOnMLFunction(Function *f) {
// Do a post order walk so that inner loop DMAs are processed first. This is
- // necessary since 'for' statements nested within would otherwise become
+ // necessary since 'for' instructions nested within would otherwise become
// invalid (erased) when the outer loop is pipelined (the pipelined one gets
// deleted and replaced by a prologue, a new steady-state loop and an
// epilogue).
- forStmts.clear();
+ forInsts.clear();
walkPostOrder(f);
bool ret = false;
- for (auto *forStmt : forStmts) {
- ret = ret | runOnForStmt(forStmt);
+ for (auto *forInst : forInsts) {
+ ret = ret | runOnForInst(forInst);
}
return ret ? failure() : success();
}
@@ -176,36 +176,36 @@ static bool checkTagMatch(OpPointer<DmaStartOp> startOp,
return true;
}
-// Identify matching DMA start/finish statements to overlap computation with.
-static void findMatchingStartFinishStmts(
- ForStmt *forStmt,
+// Identify matching DMA start/finish instructions to overlap computation with.
+static void findMatchingStartFinishInsts(
+ ForInst *forInst,
SmallVectorImpl<std::pair<OperationInst *, OperationInst *>>
&startWaitPairs) {
- // Collect outgoing DMA statements - needed to check for dependences below.
+ // Collect outgoing DMA instructions - needed to check for dependences below.
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
- for (auto &stmt : *forStmt->getBody()) {
- auto *opStmt = dyn_cast<OperationInst>(&stmt);
- if (!opStmt)
+ for (auto &inst : *forInst->getBody()) {
+ auto *opInst = dyn_cast<OperationInst>(&inst);
+ if (!opInst)
continue;
OpPointer<DmaStartOp> dmaStartOp;
- if ((dmaStartOp = opStmt->dyn_cast<DmaStartOp>()) &&
+ if ((dmaStartOp = opInst->dyn_cast<DmaStartOp>()) &&
dmaStartOp->isSrcMemorySpaceFaster())
outgoingDmaOps.push_back(dmaStartOp);
}
- SmallVector<OperationInst *, 4> dmaStartStmts, dmaFinishStmts;
- for (auto &stmt : *forStmt->getBody()) {
- auto *opStmt = dyn_cast<OperationInst>(&stmt);
- if (!opStmt)
+ SmallVector<OperationInst *, 4> dmaStartInsts, dmaFinishInsts;
+ for (auto &inst : *forInst->getBody()) {
+ auto *opInst = dyn_cast<OperationInst>(&inst);
+ if (!opInst)
continue;
- // Collect DMA finish statements.
- if (opStmt->isa<DmaWaitOp>()) {
- dmaFinishStmts.push_back(opStmt);
+ // Collect DMA finish instructions.
+ if (opInst->isa<DmaWaitOp>()) {
+ dmaFinishInsts.push_back(opInst);
continue;
}
OpPointer<DmaStartOp> dmaStartOp;
- if (!(dmaStartOp = opStmt->dyn_cast<DmaStartOp>()))
+ if (!(dmaStartOp = opInst->dyn_cast<DmaStartOp>()))
continue;
// Only DMAs incoming into higher memory spaces are pipelined for now.
// TODO(bondhugula): handle outgoing DMA pipelining.
@@ -227,7 +227,7 @@ static void findMatchingStartFinishStmts(
auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos());
bool escapingUses = false;
for (const auto &use : memref->getUses()) {
- if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) {
+ if (!dominates(*forInst->getBody()->begin(), *use.getOwner())) {
LLVM_DEBUG(llvm::dbgs()
<< "can't pipeline: buffer is live out of loop\n";);
escapingUses = true;
@@ -235,15 +235,15 @@ static void findMatchingStartFinishStmts(
}
}
if (!escapingUses)
- dmaStartStmts.push_back(opStmt);
+ dmaStartInsts.push_back(opInst);
}
- // For each start statement, we look for a matching finish statement.
- for (auto *dmaStartStmt : dmaStartStmts) {
- for (auto *dmaFinishStmt : dmaFinishStmts) {
- if (checkTagMatch(dmaStartStmt->cast<DmaStartOp>(),
- dmaFinishStmt->cast<DmaWaitOp>())) {
- startWaitPairs.push_back({dmaStartStmt, dmaFinishStmt});
+ // For each start instruction, we look for a matching finish instruction.
+ for (auto *dmaStartInst : dmaStartInsts) {
+ for (auto *dmaFinishInst : dmaFinishInsts) {
+ if (checkTagMatch(dmaStartInst->cast<DmaStartOp>(),
+ dmaFinishInst->cast<DmaWaitOp>())) {
+ startWaitPairs.push_back({dmaStartInst, dmaFinishInst});
break;
}
}
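
A toy model of this matching loop (a hypothetical `DmaInst` with an integer tag standing in for the tag memref plus indices checked by checkTagMatch): each start is paired with the first finish carrying the same tag:

#include <utility>
#include <vector>

struct DmaInst {
  int tag; // stand-in for the tag memref and tag indices
};

std::vector<std::pair<DmaInst *, DmaInst *>>
matchStartFinish(const std::vector<DmaInst *> &starts,
                 const std::vector<DmaInst *> &finishes) {
  std::vector<std::pair<DmaInst *, DmaInst *>> pairs;
  for (auto *start : starts)
    for (auto *finish : finishes)
      if (start->tag == finish->tag) { // stand-in for checkTagMatch
        pairs.push_back({start, finish});
        break; // first match wins, as above
      }
  return pairs;
}
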
@@ -251,17 +251,17 @@ static void findMatchingStartFinishStmts(
}
/// Overlap DMA transfers with computation in this loop. If successful,
-/// 'forStmt' is deleted, and a prologue, a new pipelined loop, and epilogue are
+/// 'forInst' is deleted, and a prologue, a new pipelined loop, and epilogue are
/// inserted right before where it was.
-PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
- auto mayBeConstTripCount = getConstantTripCount(*forStmt);
+PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) {
+ auto mayBeConstTripCount = getConstantTripCount(*forInst);
if (!mayBeConstTripCount.hasValue()) {
LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n");
return success();
}
SmallVector<std::pair<OperationInst *, OperationInst *>, 4> startWaitPairs;
- findMatchingStartFinishStmts(forStmt, startWaitPairs);
+ findMatchingStartFinishInsts(forInst, startWaitPairs);
if (startWaitPairs.empty()) {
LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";);
@@ -269,22 +269,22 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
}
// Double the buffers for the higher memory space memref's.
- // Identify memref's to replace by scanning through all DMA start statements.
- // A DMA start statement has two memref's - the one from the higher level of
- // memory hierarchy is the one to double buffer.
+ // Identify memref's to replace by scanning through all DMA start
+ // instructions. A DMA start instruction has two memref's - the one from the
+ // higher level of memory hierarchy is the one to double buffer.
// TODO(bondhugula): check whether double-buffering is even necessary.
// TODO(bondhugula): make this work with different layouts: assuming here that
// the dimension we are adding here for the double buffering is the outermost
// dimension.
for (auto &pair : startWaitPairs) {
- auto *dmaStartStmt = pair.first;
- Value *oldMemRef = dmaStartStmt->getOperand(
- dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos());
- if (!doubleBuffer(oldMemRef, forStmt)) {
+ auto *dmaStartInst = pair.first;
+ Value *oldMemRef = dmaStartInst->getOperand(
+ dmaStartInst->cast<DmaStartOp>()->getFasterMemPos());
+ if (!doubleBuffer(oldMemRef, forInst)) {
// Normally, double buffering should not fail because we already checked
// that there are no uses outside.
LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
- LLVM_DEBUG(dmaStartStmt->dump());
+ LLVM_DEBUG(dmaStartInst->dump());
// IR still in a valid state.
return success();
}
@@ -293,80 +293,80 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
// operation could have been used on it if it was dynamically shaped in
// order to create the double buffer above)
if (oldMemRef->use_empty())
- if (auto *allocStmt = oldMemRef->getDefiningInst())
- allocStmt->erase();
+ if (auto *allocInst = oldMemRef->getDefiningInst())
+ allocInst->erase();
}
// Double the buffers for tag memrefs.
for (auto &pair : startWaitPairs) {
- auto *dmaFinishStmt = pair.second;
+ auto *dmaFinishInst = pair.second;
Value *oldTagMemRef =
- dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt));
- if (!doubleBuffer(oldTagMemRef, forStmt)) {
+ dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst));
+ if (!doubleBuffer(oldTagMemRef, forInst)) {
LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
return success();
}
// If the old tag has no more uses, remove its 'dead' alloc if it was
// alloc'ed.
if (oldTagMemRef->use_empty())
- if (auto *allocStmt = oldTagMemRef->getDefiningInst())
- allocStmt->erase();
+ if (auto *allocInst = oldTagMemRef->getDefiningInst())
+ allocInst->erase();
}
- // Double buffering would have invalidated all the old DMA start/wait stmts.
+ // Double buffering would have invalidated all the old DMA start/wait insts.
startWaitPairs.clear();
- findMatchingStartFinishStmts(forStmt, startWaitPairs);
+ findMatchingStartFinishInsts(forInst, startWaitPairs);
- // Store shift for statement for later lookup for AffineApplyOp's.
- DenseMap<const Statement *, unsigned> stmtShiftMap;
+ // Store each instruction's shift for later lookup for AffineApplyOp's.
+ DenseMap<const Instruction *, unsigned> instShiftMap;
for (auto &pair : startWaitPairs) {
- auto *dmaStartStmt = pair.first;
- assert(dmaStartStmt->isa<DmaStartOp>());
- stmtShiftMap[dmaStartStmt] = 0;
- // Set shifts for DMA start stmt's affine operand computation slices to 0.
- if (auto *slice = mlir::createAffineComputationSlice(dmaStartStmt)) {
- stmtShiftMap[slice] = 0;
+ auto *dmaStartInst = pair.first;
+ assert(dmaStartInst->isa<DmaStartOp>());
+ instShiftMap[dmaStartInst] = 0;
+ // Set shifts for DMA start inst's affine operand computation slices to 0.
+ if (auto *slice = mlir::createAffineComputationSlice(dmaStartInst)) {
+ instShiftMap[slice] = 0;
} else {
// If a slice wasn't created, the reachable affine_apply op's from its
// operands are the ones that go with it.
- SmallVector<OperationInst *, 4> affineApplyStmts;
- SmallVector<Value *, 4> operands(dmaStartStmt->getOperands());
- getReachableAffineApplyOps(operands, affineApplyStmts);
- for (const auto *stmt : affineApplyStmts) {
- stmtShiftMap[stmt] = 0;
+ SmallVector<OperationInst *, 4> affineApplyInsts;
+ SmallVector<Value *, 4> operands(dmaStartInst->getOperands());
+ getReachableAffineApplyOps(operands, affineApplyInsts);
+ for (const auto *inst : affineApplyInsts) {
+ instShiftMap[inst] = 0;
}
}
}
// Everything else (including compute ops and dma finish) are shifted by one.
- for (const auto &stmt : *forStmt->getBody()) {
- if (stmtShiftMap.find(&stmt) == stmtShiftMap.end()) {
- stmtShiftMap[&stmt] = 1;
+ for (const auto &inst : *forInst->getBody()) {
+ if (instShiftMap.find(&inst) == instShiftMap.end()) {
+ instShiftMap[&inst] = 1;
}
}
// Get shifts stored in map.
- std::vector<uint64_t> shifts(forStmt->getBody()->getInstructions().size());
+ std::vector<uint64_t> shifts(forInst->getBody()->getInstructions().size());
unsigned s = 0;
- for (auto &stmt : *forStmt->getBody()) {
- assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end());
- shifts[s++] = stmtShiftMap[&stmt];
+ for (auto &inst : *forInst->getBody()) {
+ assert(instShiftMap.find(&inst) != instShiftMap.end());
+ shifts[s++] = instShiftMap[&inst];
LLVM_DEBUG(
- // Tagging statements with shifts for debugging purposes.
- if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
- FuncBuilder b(opStmt);
- opStmt->setAttr(b.getIdentifier("shift"),
+ // Tagging instructions with shifts for debugging purposes.
+ if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
+ FuncBuilder b(opInst);
+ opInst->setAttr(b.getIdentifier("shift"),
b.getI64IntegerAttr(shifts[s - 1]));
});
}
- if (!isStmtwiseShiftValid(*forStmt, shifts)) {
+ if (!isInstwiseShiftValid(*forInst, shifts)) {
// Violates dependences.
LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
return success();
}
- if (stmtBodySkew(forStmt, shifts)) {
- LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed - unexpected\n";);
+ if (instBodySkew(forInst, shifts)) {
+ LLVM_DEBUG(llvm::dbgs() << "inst body skewing failed - unexpected\n";);
return success();
}
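
The shift assignment above (0 for DMA starts and their index computations, 1 for everything else) skews the loop so iteration i's compute overlaps iteration i+1's transfer. A standalone model of the schedule instBodySkew produces:

#include <cstdio>

void pipelinedScheduleModel(int tripCount) {
  std::printf("prologue: dma_start(0)\n");
  for (int i = 0; i < tripCount - 1; ++i)
    std::printf("steady:   dma_start(%d); dma_wait(%d); compute(%d)\n",
                i + 1, i, i);
  std::printf("epilogue: dma_wait(%d); compute(%d)\n", tripCount - 1,
              tripCount - 1);
}
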
diff --git a/mlir/lib/Transforms/SimplifyAffineExpr.cpp b/mlir/lib/Transforms/SimplifyAffineExpr.cpp
index 853a814e516..2a643eb690a 100644
--- a/mlir/lib/Transforms/SimplifyAffineExpr.cpp
+++ b/mlir/lib/Transforms/SimplifyAffineExpr.cpp
@@ -21,7 +21,7 @@
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/IR/Function.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/Passes.h"
@@ -32,12 +32,12 @@ using llvm::report_fatal_error;
namespace {
-/// Simplifies all affine expressions appearing in the operation statements of
+/// Simplifies all affine expressions appearing in the operation instructions of
/// the Function. This is mainly to test the simplifyAffineExpr method.
// TODO(someone): Gradually extend this to all affine map references found in
// ML functions and CFG functions.
struct SimplifyAffineStructures : public FunctionPass,
- StmtWalker<SimplifyAffineStructures> {
+ InstWalker<SimplifyAffineStructures> {
explicit SimplifyAffineStructures()
: FunctionPass(&SimplifyAffineStructures::passID) {}
@@ -46,8 +46,8 @@ struct SimplifyAffineStructures : public FunctionPass,
// for this yet? TODO(someone).
PassResult runOnCFGFunction(Function *f) override { return success(); }
- void visitIfStmt(IfStmt *ifStmt);
- void visitOperationInst(OperationInst *opStmt);
+ void visitIfInst(IfInst *ifInst);
+ void visitOperationInst(OperationInst *opInst);
static char passID;
};
@@ -70,18 +70,18 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) {
return set;
}
-void SimplifyAffineStructures::visitIfStmt(IfStmt *ifStmt) {
- auto set = ifStmt->getCondition().getIntegerSet();
- ifStmt->setIntegerSet(simplifyIntegerSet(set));
+void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) {
+ auto set = ifInst->getCondition().getIntegerSet();
+ ifInst->setIntegerSet(simplifyIntegerSet(set));
}
-void SimplifyAffineStructures::visitOperationInst(OperationInst *opStmt) {
- for (auto attr : opStmt->getAttrs()) {
+void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) {
+ for (auto attr : opInst->getAttrs()) {
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
MutableAffineMap mMap(mapAttr.getValue());
mMap.simplify();
auto map = mMap.getAffineMap();
- opStmt->setAttr(attr.first, AffineMapAttr::get(map));
+ opInst->setAttr(attr.first, AffineMapAttr::get(map));
}
}
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index a4116667794..6064d1feff3 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -271,7 +271,7 @@ static void processMLFunction(Function *fn,
}
void setInsertionPoint(OperationInst *op) override {
- // Any new operations should be added before this statement.
+ // Any new operations should be added before this instruction.
builder.setInsertionPoint(cast<OperationInst>(op));
}
@@ -280,7 +280,7 @@ static void processMLFunction(Function *fn,
};
GreedyPatternRewriteDriver driver(std::move(patterns));
- fn->walk([&](OperationInst *stmt) { driver.addToWorklist(stmt); });
+ fn->walk([&](OperationInst *inst) { driver.addToWorklist(inst); });
FuncBuilder mlBuilder(fn);
MLFuncRewriter rewriter(driver, mlBuilder);
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 03b4bb29e19..93039372121 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -26,8 +26,8 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Statements.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
@@ -38,22 +38,22 @@ using namespace mlir;
/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
/// the specified trip count, stride, and unroll factor. Returns a null
/// AffineMap when the trip count can't be expressed as an affine expression.
-AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
+AffineMap mlir::getUnrolledLoopUpperBound(const ForInst &forInst,
unsigned unrollFactor,
FuncBuilder *builder) {
- auto lbMap = forStmt.getLowerBoundMap();
+ auto lbMap = forInst.getLowerBoundMap();
// Single result lower bound map only.
if (lbMap.getNumResults() != 1)
return AffineMap::Null();
// Sometimes, the trip count cannot be expressed as an affine expression.
- auto tripCount = getTripCountExpr(forStmt);
+ auto tripCount = getTripCountExpr(forInst);
if (!tripCount)
return AffineMap::Null();
AffineExpr lb(lbMap.getResult(0));
- unsigned step = forStmt.getStep();
+ unsigned step = forInst.getStep();
auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step;
return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
@@ -64,122 +64,122 @@ AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
/// bound 'lb' and with the specified trip count, stride, and unroll factor.
/// Returns an AffineMap with nullptr storage (that evaluates to false)
/// when the trip count can't be expressed as an affine expression.
-AffineMap mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
+AffineMap mlir::getCleanupLoopLowerBound(const ForInst &forInst,
unsigned unrollFactor,
FuncBuilder *builder) {
- auto lbMap = forStmt.getLowerBoundMap();
+ auto lbMap = forInst.getLowerBoundMap();
// Single result lower bound map only.
if (lbMap.getNumResults() != 1)
return AffineMap::Null();
// Sometimes the trip count cannot be expressed as an affine expression.
- AffineExpr tripCount(getTripCountExpr(forStmt));
+ AffineExpr tripCount(getTripCountExpr(forInst));
if (!tripCount)
return AffineMap::Null();
AffineExpr lb(lbMap.getResult(0));
- unsigned step = forStmt.getStep();
+ unsigned step = forInst.getStep();
auto newLb = lb + (tripCount - tripCount % unrollFactor) * step;
return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
{newLb}, {});
}
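To make the two bound formulas concrete: with lb = 0, step = 1, trip count 10, and unroll factor 4, the unrolled loop's upper bound map evaluates to 0 + (10 - 10 % 4 - 1) * 1 = 7 and the cleanup loop's lower bound to 0 + (10 - 10 % 4) * 1 = 8, so the unrolled loop covers the first 8 iterations and the cleanup loop handles the remaining 2.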
-/// Promotes the loop body of a forStmt to its containing block if the forStmt
+/// Promotes the loop body of a forInst to its containing block if the forInst
/// was known to have a single iteration. Returns false otherwise.
// TODO(bondhugula): extend this for arbitrary affine bounds.
-bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
- Optional<uint64_t> tripCount = getConstantTripCount(*forStmt);
+bool mlir::promoteIfSingleIteration(ForInst *forInst) {
+ Optional<uint64_t> tripCount = getConstantTripCount(*forInst);
if (!tripCount.hasValue() || tripCount.getValue() != 1)
return false;
// TODO(mlir-team): there is no builder for a max.
- if (forStmt->getLowerBoundMap().getNumResults() != 1)
+ if (forInst->getLowerBoundMap().getNumResults() != 1)
return false;
  // Replaces all uses of the IV with its single iteration value.
- if (!forStmt->use_empty()) {
- if (forStmt->hasConstantLowerBound()) {
- auto *mlFunc = forStmt->getFunction();
+ if (!forInst->use_empty()) {
+ if (forInst->hasConstantLowerBound()) {
+ auto *mlFunc = forInst->getFunction();
FuncBuilder topBuilder(&mlFunc->getBody()->front());
auto constOp = topBuilder.create<ConstantIndexOp>(
- forStmt->getLoc(), forStmt->getConstantLowerBound());
- forStmt->replaceAllUsesWith(constOp);
+ forInst->getLoc(), forInst->getConstantLowerBound());
+ forInst->replaceAllUsesWith(constOp);
} else {
- const AffineBound lb = forStmt->getLowerBound();
+ const AffineBound lb = forInst->getLowerBound();
SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
- FuncBuilder builder(forStmt->getBlock(), Block::iterator(forStmt));
+ FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst));
auto affineApplyOp = builder.create<AffineApplyOp>(
- forStmt->getLoc(), lb.getMap(), lbOperands);
- forStmt->replaceAllUsesWith(affineApplyOp->getResult(0));
+ forInst->getLoc(), lb.getMap(), lbOperands);
+ forInst->replaceAllUsesWith(affineApplyOp->getResult(0));
}
}
- // Move the loop body statements to the loop's containing block.
- auto *block = forStmt->getBlock();
- block->getInstructions().splice(Block::iterator(forStmt),
- forStmt->getBody()->getInstructions());
- forStmt->erase();
+ // Move the loop body instructions to the loop's containing block.
+ auto *block = forInst->getBlock();
+ block->getInstructions().splice(Block::iterator(forInst),
+ forInst->getBody()->getInstructions());
+ forInst->erase();
return true;
}
-/// Promotes all single iteration for stmt's in the Function, i.e., moves
+/// Promotes all single-iteration 'for' insts in the Function, i.e., moves
/// their body into the containing Block.
void mlir::promoteSingleIterationLoops(Function *f) {
// Gathers all innermost loops through a post order pruned walk.
- class LoopBodyPromoter : public StmtWalker<LoopBodyPromoter> {
+ class LoopBodyPromoter : public InstWalker<LoopBodyPromoter> {
public:
- void visitForStmt(ForStmt *forStmt) { promoteIfSingleIteration(forStmt); }
+ void visitForInst(ForInst *forInst) { promoteIfSingleIteration(forInst); }
};
LoopBodyPromoter fsw;
fsw.walkPostOrder(f);
}
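LoopBodyPromoter above is the smallest useful instance of the renamed InstWalker CRTP pattern. A sketch of the same idiom for a different task, assuming only the interface visible in this patch (visitForInst, walkPostOrder); the counter itself is hypothetical:

#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"

// Count 'for' instructions in a function via a post-order walk.
struct ForInstCounter : public mlir::InstWalker<ForInstCounter> {
  unsigned numLoops = 0;
  void visitForInst(mlir::ForInst *forInst) { ++numLoops; }
};

unsigned countLoops(mlir::Function *f) {
  ForInstCounter counter;
  counter.walkPostOrder(f);
  return counter.numLoops;
}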
-/// Generates a 'for' stmt with the specified lower and upper bounds while
-/// generating the right IV remappings for the shifted statements. The
-/// statement blocks that go into the loop are specified in stmtGroupQueue
+/// Generates a 'for' inst with the specified lower and upper bounds while
+/// generating the right IV remappings for the shifted instructions. The
+/// instruction blocks that go into the loop are specified in instGroupQueue
/// starting from the specified offset, and in that order; the first element of
-/// the pair specifies the shift applied to that group of statements; note that
-/// the shift is multiplied by the loop step before being applied. Returns
+/// the pair specifies the shift applied to that group of instructions; note
+/// that the shift is multiplied by the loop step before being applied. Returns
/// nullptr if the generated loop simplifies to a single iteration one.
-static ForStmt *
+static ForInst *
generateLoop(AffineMap lbMap, AffineMap ubMap,
- const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
- &stmtGroupQueue,
- unsigned offset, ForStmt *srcForStmt, FuncBuilder *b) {
- SmallVector<Value *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
- SmallVector<Value *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
+ const std::vector<std::pair<uint64_t, ArrayRef<Instruction *>>>
+ &instGroupQueue,
+ unsigned offset, ForInst *srcForInst, FuncBuilder *b) {
+ SmallVector<Value *, 4> lbOperands(srcForInst->getLowerBoundOperands());
+ SmallVector<Value *, 4> ubOperands(srcForInst->getUpperBoundOperands());
assert(lbMap.getNumInputs() == lbOperands.size());
assert(ubMap.getNumInputs() == ubOperands.size());
- auto *loopChunk = b->createFor(srcForStmt->getLoc(), lbOperands, lbMap,
- ubOperands, ubMap, srcForStmt->getStep());
+ auto *loopChunk = b->createFor(srcForInst->getLoc(), lbOperands, lbMap,
+ ubOperands, ubMap, srcForInst->getStep());
OperationInst::OperandMapTy operandMap;
- for (auto it = stmtGroupQueue.begin() + offset, e = stmtGroupQueue.end();
+ for (auto it = instGroupQueue.begin() + offset, e = instGroupQueue.end();
it != e; ++it) {
uint64_t shift = it->first;
- auto stmts = it->second;
- // All 'same shift' statements get added with their operands being remapped
- // to results of cloned statements, and their IV used remapped.
+ auto insts = it->second;
+ // All 'same shift' instructions get added with their operands being
+    // remapped to results of cloned instructions, and their IV uses remapped.
// Generate the remapping if the shift is not zero: remappedIV = newIV -
// shift.
- if (!srcForStmt->use_empty() && shift != 0) {
- auto b = FuncBuilder::getForStmtBodyBuilder(loopChunk);
+ if (!srcForInst->use_empty() && shift != 0) {
+ auto b = FuncBuilder::getForInstBodyBuilder(loopChunk);
auto *ivRemap = b.create<AffineApplyOp>(
- srcForStmt->getLoc(),
+ srcForInst->getLoc(),
b.getSingleDimShiftAffineMap(-static_cast<int64_t>(
- srcForStmt->getStep() * shift)),
+ srcForInst->getStep() * shift)),
loopChunk)
->getResult(0);
- operandMap[srcForStmt] = ivRemap;
+ operandMap[srcForInst] = ivRemap;
} else {
- operandMap[srcForStmt] = loopChunk;
+ operandMap[srcForInst] = loopChunk;
}
- for (auto *stmt : stmts) {
- loopChunk->getBody()->push_back(stmt->clone(operandMap, b->getContext()));
+ for (auto *inst : insts) {
+ loopChunk->getBody()->push_back(inst->clone(operandMap, b->getContext()));
}
}
if (promoteIfSingleIteration(loopChunk))
@@ -187,63 +187,63 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
return loopChunk;
}
-/// Skew the statements in the body of a 'for' statement with the specified
-/// statement-wise shifts. The shifts are with respect to the original execution
-/// order, and are multiplied by the loop 'step' before being applied. A shift
-/// of zero for each statement will lead to no change.
-// The skewing of statements with respect to one another can be used for example
-// to allow overlap of asynchronous operations (such as DMA communication) with
-// computation, or just relative shifting of statements for better register
-// reuse, locality or parallelism. As such, the shifts are typically expected to
-// be at most of the order of the number of statements. This method should not
-// be used as a substitute for loop distribution/fission.
-// This method uses an algorithm// in time linear in the number of statements in
-// the body of the for loop - (using the 'sweep line' paradigm). This method
+/// Skew the instructions in the body of a 'for' instruction with the specified
+/// instruction-wise shifts. The shifts are with respect to the original
+/// execution order, and are multiplied by the loop 'step' before being applied.
+/// A shift of zero for each instruction will lead to no change.
+// The skewing of instructions with respect to one another can be used for
+// example to allow overlap of asynchronous operations (such as DMA
+// communication) with computation, or just relative shifting of instructions
+// for better register reuse, locality or parallelism. As such, the shifts are
+// typically expected to be at most of the order of the number of instructions.
+// This method should not be used as a substitute for loop distribution/fission.
+// This method uses an algorithm that runs in time linear in the number of
+// instructions in the body of the for loop (the 'sweep line' paradigm). This method
// asserts preservation of SSA dominance. A check for that, as well as for
// memory-based dependence preservation, rests with the users of this
// method.
-UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
+UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
bool unrollPrologueEpilogue) {
- if (forStmt->getBody()->empty())
+ if (forInst->getBody()->empty())
return UtilResult::Success;
  // If the trip count isn't constant, we would need versioning and
// conditional guards (or context information to prevent such versioning). The
// better way to pipeline for such loops is to first tile them and extract
// constant trip count "full tiles" before applying this.
- auto mayBeConstTripCount = getConstantTripCount(*forStmt);
+ auto mayBeConstTripCount = getConstantTripCount(*forInst);
if (!mayBeConstTripCount.hasValue()) {
LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";);
return UtilResult::Success;
}
uint64_t tripCount = mayBeConstTripCount.getValue();
- assert(isStmtwiseShiftValid(*forStmt, shifts) &&
+ assert(isInstwiseShiftValid(*forInst, shifts) &&
"shifts will lead to an invalid transformation\n");
- int64_t step = forStmt->getStep();
+ int64_t step = forInst->getStep();
- unsigned numChildStmts = forStmt->getBody()->getInstructions().size();
+ unsigned numChildInsts = forInst->getBody()->getInstructions().size();
// Do a linear time (counting) sort for the shifts.
uint64_t maxShift = 0;
- for (unsigned i = 0; i < numChildStmts; i++) {
+ for (unsigned i = 0; i < numChildInsts; i++) {
maxShift = std::max(maxShift, shifts[i]);
}
// Such large shifts are not the typical use case.
- if (maxShift >= numChildStmts) {
- LLVM_DEBUG(llvm::dbgs() << "stmt shifts too large - unexpected\n";);
+ if (maxShift >= numChildInsts) {
+ LLVM_DEBUG(llvm::dbgs() << "inst shifts too large - unexpected\n";);
return UtilResult::Success;
}
- // An array of statement groups sorted by shift amount; each group has all
- // statements with the same shift in the order in which they appear in the
- // body of the 'for' stmt.
- std::vector<std::vector<Statement *>> sortedStmtGroups(maxShift + 1);
+ // An array of instruction groups sorted by shift amount; each group has all
+ // instructions with the same shift in the order in which they appear in the
+ // body of the 'for' inst.
+ std::vector<std::vector<Instruction *>> sortedInstGroups(maxShift + 1);
unsigned pos = 0;
- for (auto &stmt : *forStmt->getBody()) {
+ for (auto &inst : *forInst->getBody()) {
auto shift = shifts[pos++];
- sortedStmtGroups[shift].push_back(&stmt);
+ sortedInstGroups[shift].push_back(&inst);
}
// Unless the shifts have a specific pattern (which actually would be the
@@ -251,40 +251,40 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
// Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
// loop generated as the prologue and the last as epilogue and unroll these
// fully.
- ForStmt *prologue = nullptr;
- ForStmt *epilogue = nullptr;
+ ForInst *prologue = nullptr;
+ ForInst *epilogue = nullptr;
// Do a sweep over the sorted shifts while storing open groups in a
// vector, and generating loop portions as necessary during the sweep. A block
- // of statements is paired with its shift.
- std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> stmtGroupQueue;
+ // of instructions is paired with its shift.
+ std::vector<std::pair<uint64_t, ArrayRef<Instruction *>>> instGroupQueue;
- auto origLbMap = forStmt->getLowerBoundMap();
+ auto origLbMap = forInst->getLowerBoundMap();
uint64_t lbShift = 0;
- FuncBuilder b(forStmt);
- for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) {
+ FuncBuilder b(forInst);
+ for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) {
// If nothing is shifted by d, continue.
- if (sortedStmtGroups[d].empty())
+ if (sortedInstGroups[d].empty())
continue;
- if (!stmtGroupQueue.empty()) {
+ if (!instGroupQueue.empty()) {
assert(d >= 1 &&
"Queue expected to be empty when the first block is found");
// The interval for which the loop needs to be generated here is:
// [lbShift, min(lbShift + tripCount, d)) and the body of the
- // loop needs to have all statements in stmtQueue in that order.
- ForStmt *res;
+      // loop needs to have all instructions in 'instGroupQueue' in that order.
+ ForInst *res;
if (lbShift + tripCount * step < d * step) {
res = generateLoop(
b.getShiftedAffineMap(origLbMap, lbShift),
b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step),
- stmtGroupQueue, 0, forStmt, &b);
- // Entire loop for the queued stmt groups generated, empty it.
- stmtGroupQueue.clear();
+ instGroupQueue, 0, forInst, &b);
+      // The entire loop for the queued inst groups has been generated; empty it.
+ instGroupQueue.clear();
lbShift += tripCount * step;
} else {
res = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
- b.getShiftedAffineMap(origLbMap, d), stmtGroupQueue,
- 0, forStmt, &b);
+ b.getShiftedAffineMap(origLbMap, d), instGroupQueue,
+ 0, forInst, &b);
lbShift = d * step;
}
if (!prologue && res)
@@ -294,24 +294,24 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
// Start of first interval.
lbShift = d * step;
}
- // Augment the list of statements that get into the current open interval.
- stmtGroupQueue.push_back({d, sortedStmtGroups[d]});
+ // Augment the list of instructions that get into the current open interval.
+ instGroupQueue.push_back({d, sortedInstGroups[d]});
}
- // Those statements groups left in the queue now need to be processed (FIFO)
+  // Those instruction groups left in the queue now need to be processed (FIFO)
// and their loops completed.
- for (unsigned i = 0, e = stmtGroupQueue.size(); i < e; ++i) {
- uint64_t ubShift = (stmtGroupQueue[i].first + tripCount) * step;
+ for (unsigned i = 0, e = instGroupQueue.size(); i < e; ++i) {
+ uint64_t ubShift = (instGroupQueue[i].first + tripCount) * step;
epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
b.getShiftedAffineMap(origLbMap, ubShift),
- stmtGroupQueue, i, forStmt, &b);
+ instGroupQueue, i, forInst, &b);
lbShift = ubShift;
if (!prologue)
prologue = epilogue;
}
- // Erase the original for stmt.
- forStmt->erase();
+ // Erase the original for inst.
+ forInst->erase();
if (unrollPrologueEpilogue && prologue)
loopUnrollFull(prologue);
@@ -322,39 +322,39 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
}
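For intuition about what the sweep produces: with a two-instruction body {A; B}, shifts {0, 1}, and unit step, A of iteration i executes alongside B of iteration i - 1. The generated code is a prologue loop running only A, a steady-state loop running both groups with B's IV remapped by -step, and an epilogue loop draining the final B, i.e., the overlap of asynchronous DMA with computation described in the comment above.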
/// Unrolls this loop completely.
-bool mlir::loopUnrollFull(ForStmt *forStmt) {
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
+bool mlir::loopUnrollFull(ForInst *forInst) {
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
if (mayBeConstantTripCount.hasValue()) {
uint64_t tripCount = mayBeConstantTripCount.getValue();
if (tripCount == 1) {
- return promoteIfSingleIteration(forStmt);
+ return promoteIfSingleIteration(forInst);
}
- return loopUnrollByFactor(forStmt, tripCount);
+ return loopUnrollByFactor(forInst, tripCount);
}
return false;
}
/// Unrolls and jams this loop by the specified factor or by the trip count (if
/// constant) whichever is lower.
-bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) {
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
+bool mlir::loopUnrollUpToFactor(ForInst *forInst, uint64_t unrollFactor) {
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollFactor)
- return loopUnrollByFactor(forStmt, mayBeConstantTripCount.getValue());
- return loopUnrollByFactor(forStmt, unrollFactor);
+ return loopUnrollByFactor(forInst, mayBeConstantTripCount.getValue());
+ return loopUnrollByFactor(forInst, unrollFactor);
}
/// Unrolls this loop by the specified factor. Returns true if the loop
/// is successfully unrolled.
-bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
+bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) {
assert(unrollFactor >= 1 && "unroll factor should be >= 1");
- if (unrollFactor == 1 || forStmt->getBody()->empty())
+ if (unrollFactor == 1 || forInst->getBody()->empty())
return false;
- auto lbMap = forStmt->getLowerBoundMap();
- auto ubMap = forStmt->getUpperBoundMap();
+ auto lbMap = forInst->getLowerBoundMap();
+ auto ubMap = forInst->getUpperBoundMap();
// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as a Function in the general case). However, the right way to
@@ -365,10 +365,10 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
// Same operand list for lower and upper bound for now.
// TODO(bondhugula): handle bounds with different operand lists.
- if (!forStmt->matchingBoundOperandList())
+ if (!forInst->matchingBoundOperandList())
return false;
- Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
+ Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
// If the trip count is lower than the unroll factor, no unrolled body.
// TODO(bondhugula): option to specify cleanup loop unrolling.
@@ -377,64 +377,64 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
return false;
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
- if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) {
+ if (getLargestDivisorOfTripCount(*forInst) % unrollFactor != 0) {
DenseMap<const Value *, Value *> operandMap;
- FuncBuilder builder(forStmt->getBlock(), ++Block::iterator(forStmt));
- auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
- auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
+ FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst));
+ auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst, operandMap));
+ auto clLbMap = getCleanupLoopLowerBound(*forInst, unrollFactor, &builder);
assert(clLbMap &&
"cleanup loop lower bound map for single result bound maps can "
"always be determined");
- cleanupForStmt->setLowerBoundMap(clLbMap);
+ cleanupForInst->setLowerBoundMap(clLbMap);
// Promote the loop body up if this has turned into a single iteration loop.
- promoteIfSingleIteration(cleanupForStmt);
+ promoteIfSingleIteration(cleanupForInst);
// Adjust upper bound.
auto unrolledUbMap =
- getUnrolledLoopUpperBound(*forStmt, unrollFactor, &builder);
+ getUnrolledLoopUpperBound(*forInst, unrollFactor, &builder);
assert(unrolledUbMap &&
"upper bound map can alwayys be determined for an unrolled loop "
"with single result bounds");
- forStmt->setUpperBoundMap(unrolledUbMap);
+ forInst->setUpperBoundMap(unrolledUbMap);
}
// Scale the step of loop being unrolled by unroll factor.
- int64_t step = forStmt->getStep();
- forStmt->setStep(step * unrollFactor);
+ int64_t step = forInst->getStep();
+ forInst->setStep(step * unrollFactor);
- // Builder to insert unrolled bodies right after the last statement in the
- // body of 'forStmt'.
- FuncBuilder builder(forStmt->getBody(), forStmt->getBody()->end());
+ // Builder to insert unrolled bodies right after the last instruction in the
+ // body of 'forInst'.
+ FuncBuilder builder(forInst->getBody(), forInst->getBody()->end());
- // Keep a pointer to the last statement in the original block so that we know
- // what to clone (since we are doing this in-place).
- Block::iterator srcBlockEnd = std::prev(forStmt->getBody()->end());
+ // Keep a pointer to the last instruction in the original block so that we
+ // know what to clone (since we are doing this in-place).
+ Block::iterator srcBlockEnd = std::prev(forInst->getBody()->end());
- // Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies).
+ // Unroll the contents of 'forInst' (append unrollFactor-1 additional copies).
for (unsigned i = 1; i < unrollFactor; i++) {
DenseMap<const Value *, Value *> operandMap;
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
- if (!forStmt->use_empty()) {
+ if (!forInst->use_empty()) {
// iv' = iv + 1/2/3...unrollFactor-1;
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto *ivUnroll =
- builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
+ builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInst)
->getResult(0);
- operandMap[forStmt] = ivUnroll;
+ operandMap[forInst] = ivUnroll;
}
- // Clone the original body of 'forStmt'.
- for (auto it = forStmt->getBody()->begin(); it != std::next(srcBlockEnd);
+ // Clone the original body of 'forInst'.
+ for (auto it = forInst->getBody()->begin(); it != std::next(srcBlockEnd);
it++) {
builder.clone(*it, operandMap);
}
}
// Promote the loop body up if this has turned into a single iteration loop.
- promoteIfSingleIteration(forStmt);
+ promoteIfSingleIteration(forInst);
return true;
}
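A minimal driver sketch for the unroll entry points above, using only the signatures shown in this file; the heuristic, the header location, and the function name are assumptions:

#include "mlir/Transforms/LoopUtils.h" // assumed header for these declarations

// Fully unroll loops with a known trip count; otherwise unroll by at
// most 4, letting loopUnrollUpToFactor clamp to the trip count.
static bool unrollHeuristically(mlir::ForInst *forInst) {
  if (mlir::loopUnrollFull(forInst))
    return true;
  return mlir::loopUnrollUpToFactor(forInst, /*unrollFactor=*/4);
}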
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index 3661c1bdbbc..8cfe2619e2a 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -26,8 +26,8 @@
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Module.h"
-#include "mlir/IR/StmtVisitor.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseMap.h"
@@ -66,7 +66,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
ArrayRef<Value *> extraIndices,
AffineMap indexRemap,
ArrayRef<Value *> extraOperands,
- const Statement *domStmtFilter) {
+ const Instruction *domInstFilter) {
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank; // unused in opt mode
unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
@@ -85,41 +85,41 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
  // Walk all uses of the old memref. Operations using the memref get replaced.
for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
InstOperand &use = *(it++);
- auto *opStmt = cast<OperationInst>(use.getOwner());
+ auto *opInst = cast<OperationInst>(use.getOwner());
- // Skip this use if it's not dominated by domStmtFilter.
- if (domStmtFilter && !dominates(*domStmtFilter, *opStmt))
+ // Skip this use if it's not dominated by domInstFilter.
+ if (domInstFilter && !dominates(*domInstFilter, *opInst))
continue;
    // Check if the memref was used in a non-dereferencing context. It is fine
    // for the memref to be used in a non-dereferencing way outside of the
    // region where this replacement is happening.
// where this replacement is happening.
- if (!isMemRefDereferencingOp(*opStmt))
+ if (!isMemRefDereferencingOp(*opInst))
      // Failure: memref used in a non-dereferencing op (potentially escapes);
      // no replacement in these cases.
// replacement in these cases.
return false;
auto getMemRefOperandPos = [&]() -> unsigned {
unsigned i, e;
- for (i = 0, e = opStmt->getNumOperands(); i < e; i++) {
- if (opStmt->getOperand(i) == oldMemRef)
+ for (i = 0, e = opInst->getNumOperands(); i < e; i++) {
+ if (opInst->getOperand(i) == oldMemRef)
break;
}
- assert(i < opStmt->getNumOperands() && "operand guaranteed to be found");
+ assert(i < opInst->getNumOperands() && "operand guaranteed to be found");
return i;
};
unsigned memRefOperandPos = getMemRefOperandPos();
- // Construct the new operation statement using this memref.
- OperationState state(opStmt->getContext(), opStmt->getLoc(),
- opStmt->getName());
- state.operands.reserve(opStmt->getNumOperands() + extraIndices.size());
+ // Construct the new operation instruction using this memref.
+ OperationState state(opInst->getContext(), opInst->getLoc(),
+ opInst->getName());
+ state.operands.reserve(opInst->getNumOperands() + extraIndices.size());
// Insert the non-memref operands.
- state.operands.insert(state.operands.end(), opStmt->operand_begin(),
- opStmt->operand_begin() + memRefOperandPos);
+ state.operands.insert(state.operands.end(), opInst->operand_begin(),
+ opInst->operand_begin() + memRefOperandPos);
state.operands.push_back(newMemRef);
- FuncBuilder builder(opStmt);
+ FuncBuilder builder(opInst);
for (auto *extraIndex : extraIndices) {
// TODO(mlir-team): An operation/SSA value should provide a method to
// return the position of an SSA result in its defining
@@ -139,10 +139,10 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
remapOperands.insert(remapOperands.end(), extraOperands.begin(),
extraOperands.end());
remapOperands.insert(
- remapOperands.end(), opStmt->operand_begin() + memRefOperandPos + 1,
- opStmt->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
+ remapOperands.end(), opInst->operand_begin() + memRefOperandPos + 1,
+ opInst->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
if (indexRemap) {
- auto remapOp = builder.create<AffineApplyOp>(opStmt->getLoc(), indexRemap,
+ auto remapOp = builder.create<AffineApplyOp>(opInst->getLoc(), indexRemap,
remapOperands);
// Remapped indices.
for (auto *index : remapOp->getInstruction()->getResults())
@@ -155,27 +155,27 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
// Insert the remaining operands unmodified.
state.operands.insert(state.operands.end(),
- opStmt->operand_begin() + memRefOperandPos + 1 +
+ opInst->operand_begin() + memRefOperandPos + 1 +
oldMemRefRank,
- opStmt->operand_end());
+ opInst->operand_end());
    // Result types don't change. Both memrefs are of the same elemental type.
- state.types.reserve(opStmt->getNumResults());
- for (const auto *result : opStmt->getResults())
+ state.types.reserve(opInst->getNumResults());
+ for (const auto *result : opInst->getResults())
state.types.push_back(result->getType());
// Attributes also do not change.
- state.attributes.insert(state.attributes.end(), opStmt->getAttrs().begin(),
- opStmt->getAttrs().end());
+ state.attributes.insert(state.attributes.end(), opInst->getAttrs().begin(),
+ opInst->getAttrs().end());
// Create the new operation.
auto *repOp = builder.createOperation(state);
    // Replace the old memref's dereferencing op's uses.
unsigned r = 0;
- for (auto *res : opStmt->getResults()) {
+ for (auto *res : opInst->getResults()) {
res->replaceAllUsesWith(repOp->getResult(r++));
}
- opStmt->erase();
+ opInst->erase();
}
return true;
}
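The simplest call into the rewritten utility keeps the original access functions: no extra indices, a null index remap, and a null domInstFilter so every dereferencing use is rewritten. A sketch; the wrapper name and header location are assumptions:

#include "mlir/Transforms/Utils.h" // assumed header for the declaration

static bool swapBuffers(mlir::Value *oldMemRef, mlir::Value *newMemRef) {
  return mlir::replaceAllMemRefUsesWith(oldMemRef, newMemRef,
                                        /*extraIndices=*/{},
                                        /*indexRemap=*/mlir::AffineMap::Null(),
                                        /*extraOperands=*/{},
                                        /*domInstFilter=*/nullptr);
}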
@@ -196,9 +196,9 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
// Initialize AffineValueMap with identity map.
AffineValueMap valueMap(map, operands);
- for (auto *opStmt : affineApplyOps) {
- assert(opStmt->isa<AffineApplyOp>());
- auto affineApplyOp = opStmt->cast<AffineApplyOp>();
+ for (auto *opInst : affineApplyOps) {
+ assert(opInst->isa<AffineApplyOp>());
+ auto affineApplyOp = opInst->cast<AffineApplyOp>();
// Forward substitute 'affineApplyOp' into 'valueMap'.
valueMap.forwardSubstitute(*affineApplyOp);
}
@@ -219,10 +219,10 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
return affineApplyOp->getInstruction();
}
-/// Given an operation statement, inserts a new single affine apply operation,
-/// that is exclusively used by this operation statement, and that provides all
-/// operands that are results of an affine_apply as a function of loop iterators
-/// and program parameters and whose results are.
+/// Given an operation instruction, inserts a new single affine apply operation
+/// that is exclusively used by this operation instruction, and that provides
+/// all operands that are results of an affine_apply as a function of loop
+/// iterators and program parameters; those operands are then replaced by the
+/// results of the new affine apply operation.
///
/// Before
///
@@ -242,18 +242,18 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
/// This allows applying different transformations on send and compute (e.g.
/// different shifts/delays).
///
-/// Returns nullptr either if none of opStmt's operands were the result of an
+/// Returns nullptr either if none of opInst's operands were the result of an
/// affine_apply and thus there was no affine computation slice to create, or if
-/// all the affine_apply op's supplying operands to this opStmt do not have any
-/// uses besides this opStmt. Returns the new affine_apply operation statement
+/// all the affine_apply ops supplying operands to this opInst do not have any
+/// uses besides this opInst. Returns the new affine_apply operation instruction
/// otherwise.
-OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) {
+OperationInst *mlir::createAffineComputationSlice(OperationInst *opInst) {
// Collect all operands that are results of affine apply ops.
SmallVector<Value *, 4> subOperands;
- subOperands.reserve(opStmt->getNumOperands());
- for (auto *operand : opStmt->getOperands()) {
- auto *defStmt = operand->getDefiningInst();
- if (defStmt && defStmt->isa<AffineApplyOp>()) {
+ subOperands.reserve(opInst->getNumOperands());
+ for (auto *operand : opInst->getOperands()) {
+ auto *defInst = operand->getDefiningInst();
+ if (defInst && defInst->isa<AffineApplyOp>()) {
subOperands.push_back(operand);
}
}
@@ -265,13 +265,13 @@ OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) {
if (affineApplyOps.empty())
return nullptr;
- // Check if all uses of the affine apply op's lie only in this op stmt, in
+ // Check if all uses of the affine apply op's lie only in this op inst, in
// which case there would be nothing to do.
bool localized = true;
for (auto *op : affineApplyOps) {
for (auto *result : op->getResults()) {
for (auto &use : result->getUses()) {
- if (use.getOwner() != opStmt) {
+ if (use.getOwner() != opInst) {
localized = false;
break;
}
@@ -281,18 +281,18 @@ OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) {
if (localized)
return nullptr;
- FuncBuilder builder(opStmt);
+ FuncBuilder builder(opInst);
SmallVector<Value *, 4> results;
- auto *affineApplyStmt = createComposedAffineApplyOp(
- &builder, opStmt->getLoc(), subOperands, affineApplyOps, &results);
+ auto *affineApplyInst = createComposedAffineApplyOp(
+ &builder, opInst->getLoc(), subOperands, affineApplyOps, &results);
assert(results.size() == subOperands.size() &&
"number of results should be the same as the number of subOperands");
// Construct the new operands that include the results from the composed
// affine apply op above instead of existing ones (subOperands). So, they
- // differ from opStmt's operands only for those operands in 'subOperands', for
+ // differ from opInst's operands only for those operands in 'subOperands', for
// which they will be replaced by the corresponding one from 'results'.
- SmallVector<Value *, 4> newOperands(opStmt->getOperands());
+ SmallVector<Value *, 4> newOperands(opInst->getOperands());
for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
// Replace the subOperands from among the new operands.
unsigned j, f;
@@ -306,10 +306,10 @@ OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) {
}
for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) {
- opStmt->setOperand(idx, newOperands[idx]);
+ opInst->setOperand(idx, newOperands[idx]);
}
- return affineApplyStmt;
+ return affineApplyInst;
}
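Typical use, per the doc comment above: give an op a private copy of its index computation so it can be shifted independently of other users. A sketch relying only on the nullptr contract shown here; the wrapper name is hypothetical:

static void isolateIndexComputation(mlir::OperationInst *opInst) {
  if (mlir::OperationInst *slice = mlir::createAffineComputationSlice(opInst)) {
    // 'slice' is a fresh affine_apply used exclusively by opInst; it can
    // now be shifted or pipelined without disturbing other users.
    (void)slice;
  }
  // nullptr: no affine_apply operands, or they already fed only opInst.
}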
void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
@@ -317,26 +317,26 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
// TODO: Support forward substitution for CFG style functions.
return;
}
- auto *opStmt = affineApplyOp->getInstruction();
- // Iterate through all uses of all results of 'opStmt', forward substituting
+ auto *opInst = affineApplyOp->getInstruction();
+ // Iterate through all uses of all results of 'opInst', forward substituting
// into any uses which are AffineApplyOps.
- for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e;
+ for (unsigned resultIndex = 0, e = opInst->getNumResults(); resultIndex < e;
++resultIndex) {
- const Value *result = opStmt->getResult(resultIndex);
+ const Value *result = opInst->getResult(resultIndex);
for (auto it = result->use_begin(); it != result->use_end();) {
InstOperand &use = *(it++);
- auto *useStmt = use.getOwner();
- auto *useOpStmt = dyn_cast<OperationInst>(useStmt);
+ auto *useInst = use.getOwner();
+ auto *useOpInst = dyn_cast<OperationInst>(useInst);
// Skip if use is not AffineApplyOp.
- if (useOpStmt == nullptr || !useOpStmt->isa<AffineApplyOp>())
+ if (useOpInst == nullptr || !useOpInst->isa<AffineApplyOp>())
continue;
- // Advance iterator past 'opStmt' operands which also use 'result'.
- while (it != result->use_end() && it->getOwner() == useStmt)
+ // Advance iterator past 'opInst' operands which also use 'result'.
+ while (it != result->use_end() && it->getOwner() == useInst)
++it;
- FuncBuilder builder(useOpStmt);
+ FuncBuilder builder(useOpInst);
// Initialize AffineValueMap with 'affineApplyOp' which uses 'result'.
- auto oldAffineApplyOp = useOpStmt->cast<AffineApplyOp>();
+ auto oldAffineApplyOp = useOpInst->cast<AffineApplyOp>();
AffineValueMap valueMap(*oldAffineApplyOp);
// Forward substitute 'result' at index 'i' into 'valueMap'.
valueMap.forwardSubstituteSingle(*affineApplyOp, resultIndex);
@@ -348,10 +348,10 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
operands[i] = valueMap.getOperand(i);
}
auto newAffineApplyOp = builder.create<AffineApplyOp>(
- useOpStmt->getLoc(), valueMap.getAffineMap(), operands);
+ useOpInst->getLoc(), valueMap.getAffineMap(), operands);
// Update all uses to use results from 'newAffineApplyOp'.
- for (unsigned i = 0, e = useOpStmt->getNumResults(); i < e; ++i) {
+ for (unsigned i = 0, e = useOpInst->getNumResults(); i < e; ++i) {
oldAffineApplyOp->getResult(i)->replaceAllUsesWith(
newAffineApplyOp->getResult(i));
}
@@ -364,19 +364,19 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
/// Folds the lower and upper bounds of the 'for' instruction to constants
/// where possible, considering their operands. Returns false if either bound
/// was folded to a constant, true otherwise.
-bool mlir::constantFoldBounds(ForStmt *forStmt) {
- auto foldLowerOrUpperBound = [forStmt](bool lower) {
+bool mlir::constantFoldBounds(ForInst *forInst) {
+ auto foldLowerOrUpperBound = [forInst](bool lower) {
// Check if the bound is already a constant.
- if (lower && forStmt->hasConstantLowerBound())
+ if (lower && forInst->hasConstantLowerBound())
return true;
- if (!lower && forStmt->hasConstantUpperBound())
+ if (!lower && forInst->hasConstantUpperBound())
return true;
// Check to see if each of the operands is the result of a constant. If so,
// get the value. If not, ignore it.
SmallVector<Attribute, 8> operandConstants;
- auto boundOperands = lower ? forStmt->getLowerBoundOperands()
- : forStmt->getUpperBoundOperands();
+ auto boundOperands = lower ? forInst->getLowerBoundOperands()
+ : forInst->getUpperBoundOperands();
for (const auto *operand : boundOperands) {
Attribute operandCst;
if (auto *operandOp = operand->getDefiningInst()) {
@@ -387,7 +387,7 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
}
AffineMap boundMap =
- lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap();
+ lower ? forInst->getLowerBoundMap() : forInst->getUpperBoundMap();
assert(boundMap.getNumResults() >= 1 &&
"bound maps should have at least one result");
SmallVector<Attribute, 4> foldedResults;
@@ -402,8 +402,8 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
: llvm::APIntOps::smin(maxOrMin, foldedResult);
}
- lower ? forStmt->setConstantLowerBound(maxOrMin.getSExtValue())
- : forStmt->setConstantUpperBound(maxOrMin.getSExtValue());
+ lower ? forInst->setConstantLowerBound(maxOrMin.getSExtValue())
+ : forInst->setConstantUpperBound(maxOrMin.getSExtValue());
// Return false on success.
return false;
@@ -449,11 +449,11 @@ void mlir::remapFunctionAttrs(
if (!fn.isML())
return;
- struct MLFnWalker : public StmtWalker<MLFnWalker> {
+ struct MLFnWalker : public InstWalker<MLFnWalker> {
MLFnWalker(const DenseMap<Attribute, FunctionAttr> &remappingTable)
: remappingTable(remappingTable) {}
- void visitOperationInst(OperationInst *opStmt) {
- remapFunctionAttrs(*opStmt, remappingTable);
+ void visitOperationInst(OperationInst *opInst) {
+ remapFunctionAttrs(*opInst, remappingTable);
}
const DenseMap<Attribute, FunctionAttr> &remappingTable;
diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
index 78d048b4778..9aa11682ebb 100644
--- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
+++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
@@ -95,20 +95,20 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
SmallVector<int, 8> shape(clTestVectorShapeRatio.begin(),
clTestVectorShapeRatio.end());
auto subVectorType = VectorType::get(shape, Type::getF32(f->getContext()));
- // Only filter statements that operate on a strict super-vector and have one
+ // Only filter instructions that operate on a strict super-vector and have one
// return. This makes testing easier.
- auto filter = [subVectorType](const Statement &stmt) {
- auto *opStmt = dyn_cast<OperationInst>(&stmt);
- if (!opStmt) {
+ auto filter = [subVectorType](const Instruction &inst) {
+ auto *opInst = dyn_cast<OperationInst>(&inst);
+ if (!opInst) {
return false;
}
assert(subVectorType.getElementType() ==
Type::getF32(subVectorType.getContext()) &&
"Only f32 supported for now");
- if (!matcher::operatesOnStrictSuperVectors(*opStmt, subVectorType)) {
+ if (!matcher::operatesOnStrictSuperVectors(*opInst, subVectorType)) {
return false;
}
- if (opStmt->getNumResults() != 1) {
+ if (opInst->getNumResults() != 1) {
return false;
}
return true;
@@ -116,26 +116,26 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
auto pat = Op(filter);
auto matches = pat.match(f);
for (auto m : matches) {
- auto *opStmt = cast<OperationInst>(m.first);
+ auto *opInst = cast<OperationInst>(m.first);
// This is a unit test that only checks and prints shape ratio.
// As a consequence we write only Ops with a single return type for the
// purpose of this test. If we need to test more intricate behavior in the
// future we can always extend.
- auto superVectorType = opStmt->getResult(0)->getType().cast<VectorType>();
+ auto superVectorType = opInst->getResult(0)->getType().cast<VectorType>();
auto ratio = shapeRatio(superVectorType, subVectorType);
if (!ratio.hasValue()) {
- opStmt->emitNote("NOT MATCHED");
+ opInst->emitNote("NOT MATCHED");
} else {
- outs() << "\nmatched: " << *opStmt << " with shape ratio: ";
+ outs() << "\nmatched: " << *opInst << " with shape ratio: ";
interleaveComma(MutableArrayRef<unsigned>(*ratio), outs());
}
}
}
-static std::string toString(Statement *stmt) {
+static std::string toString(Instruction *inst) {
std::string res;
auto os = llvm::raw_string_ostream(res);
- stmt->print(os);
+  inst->print(os);
+  os.flush(); // raw_string_ostream may buffer; flush before returning 'res'
  return res;
}
@@ -144,10 +144,10 @@ static MLFunctionMatches matchTestSlicingOps(Function *f) {
constexpr auto kTestSlicingOpName = "slicing-test-op";
using functional::map;
using matcher::Op;
- // Match all OpStatements with the kTestSlicingOpName name.
- auto filter = [](const Statement &stmt) {
- const auto &opStmt = cast<OperationInst>(stmt);
- return opStmt.getName().getStringRef() == kTestSlicingOpName;
+ // Match all OpInstructions with the kTestSlicingOpName name.
+ auto filter = [](const Instruction &inst) {
+ const auto &opInst = cast<OperationInst>(inst);
+ return opInst.getName().getStringRef() == kTestSlicingOpName;
};
auto pat = Op(filter);
return pat.match(f);
@@ -156,7 +156,7 @@ static MLFunctionMatches matchTestSlicingOps(Function *f) {
void VectorizerTestPass::testBackwardSlicing(Function *f) {
auto matches = matchTestSlicingOps(f);
for (auto m : matches) {
- SetVector<Statement *> backwardSlice;
+ SetVector<Instruction *> backwardSlice;
getBackwardSlice(m.first, &backwardSlice);
auto strs = map(toString, backwardSlice);
outs() << "\nmatched: " << *m.first << " backward static slice: ";
@@ -169,7 +169,7 @@ void VectorizerTestPass::testBackwardSlicing(Function *f) {
void VectorizerTestPass::testForwardSlicing(Function *f) {
auto matches = matchTestSlicingOps(f);
for (auto m : matches) {
- SetVector<Statement *> forwardSlice;
+ SetVector<Instruction *> forwardSlice;
getForwardSlice(m.first, &forwardSlice);
auto strs = map(toString, forwardSlice);
outs() << "\nmatched: " << *m.first << " forward static slice: ";
@@ -182,7 +182,7 @@ void VectorizerTestPass::testForwardSlicing(Function *f) {
void VectorizerTestPass::testSlicing(Function *f) {
auto matches = matchTestSlicingOps(f);
for (auto m : matches) {
- SetVector<Statement *> staticSlice = getSlice(m.first);
+ SetVector<Instruction *> staticSlice = getSlice(m.first);
auto strs = map(toString, staticSlice);
outs() << "\nmatched: " << *m.first << " static slice: ";
for (const auto &s : strs) {
@@ -191,9 +191,9 @@ void VectorizerTestPass::testSlicing(Function *f) {
}
}
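The three test hooks above exercise the slice analysis entry points one-to-one. A sketch of calling one directly, assuming the getBackwardSlice entry point used by the test (declared in the slice analysis header); the wrapper name is hypothetical:

#include "llvm/ADT/SetVector.h"

// Collect the transitive producers of 'inst' in a deterministic order.
static llvm::SetVector<mlir::Instruction *>
producersOf(mlir::Instruction *inst) {
  llvm::SetVector<mlir::Instruction *> backwardSlice;
  mlir::getBackwardSlice(inst, &backwardSlice);
  return backwardSlice;
}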
-bool customOpWithAffineMapAttribute(const Statement &stmt) {
- const auto &opStmt = cast<OperationInst>(stmt);
- return opStmt.getName().getStringRef() ==
+bool customOpWithAffineMapAttribute(const Instruction &inst) {
+ const auto &opInst = cast<OperationInst>(inst);
+ return opInst.getName().getStringRef() ==
VectorizerTestPass::kTestAffineMapOpName;
}
@@ -205,8 +205,8 @@ void VectorizerTestPass::testComposeMaps(Function *f) {
maps.reserve(matches.size());
std::reverse(matches.begin(), matches.end());
for (auto m : matches) {
- auto *opStmt = cast<OperationInst>(m.first);
- auto map = opStmt->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
+ auto *opInst = cast<OperationInst>(m.first);
+ auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
.cast<AffineMapAttr>()
.getValue();
maps.push_back(map);
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index ddbd6256782..bbb703cd627 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -252,7 +252,7 @@ using namespace mlir;
/// ==========
/// The algorithm proceeds in a few steps:
/// 1. defining super-vectorization patterns and matching them on the tree of
-/// ForStmt. A super-vectorization pattern is defined as a recursive data
+///    ForInst. A super-vectorization pattern is defined as a recursive data
///    structure that matches and captures nested, imperfectly-nested loops
///    that have a. conformable loop annotations attached (e.g. parallel,
///    reduction, vectorizable, ...) as well as b. all contiguous load/store
@@ -279,7 +279,7 @@ using namespace mlir;
/// it by its vector form. Otherwise, if the scalar value is a constant,
/// it is vectorized into a splat. In all other cases, vectorization for
/// the pattern currently fails.
-/// e. if everything under the root ForStmt in the current pattern vectorizes
+/// e. if everything under the root ForInst in the current pattern vectorizes
/// properly, we commit that loop to the IR. Otherwise we discard it and
/// restore a previously cloned version of the loop. Thanks to the
/// recursive scoping nature of matchers and captured patterns, this is
@@ -668,12 +668,12 @@ namespace {
struct VectorizationStrategy {
ArrayRef<int> vectorSizes;
- DenseMap<ForStmt *, unsigned> loopToVectorDim;
+ DenseMap<ForInst *, unsigned> loopToVectorDim;
};
} // end anonymous namespace
-static void vectorizeLoopIfProfitable(ForStmt *loop, unsigned depthInPattern,
+static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern,
unsigned patternDepth,
VectorizationStrategy *strategy) {
assert(patternDepth > depthInPattern &&
@@ -705,7 +705,7 @@ static bool analyzeProfitability(MLFunctionMatches matches,
unsigned depthInPattern, unsigned patternDepth,
VectorizationStrategy *strategy) {
for (auto m : matches) {
- auto *loop = cast<ForStmt>(m.first);
+ auto *loop = cast<ForInst>(m.first);
bool fail = analyzeProfitability(m.second, depthInPattern + 1, patternDepth,
strategy);
if (fail) {
@@ -721,7 +721,7 @@ static bool analyzeProfitability(MLFunctionMatches matches,
namespace {
struct VectorizationState {
- /// Adds an entry of pre/post vectorization statements in the state.
+  /// Records a pre/post vectorization instruction pair in the state.
void registerReplacement(OperationInst *key, OperationInst *value);
/// When the current vectorization pattern is successful, this erases the
/// instructions that were marked for erasure in the proper order and resets
@@ -733,7 +733,7 @@ struct VectorizationState {
SmallVector<OperationInst *, 16> toErase;
// Set of OperationInst that have been vectorized (the values in the
// vectorizationMap for hashed access). The vectorizedSet is used in
- // particular to filter the statements that have already been vectorized by
+ // particular to filter the instructions that have already been vectorized by
// this pattern, when iterating over nested loops in this pattern.
DenseSet<OperationInst *> vectorizedSet;
// Map of old scalar OperationInst to new vectorized OperationInst.
@@ -747,16 +747,16 @@ struct VectorizationState {
// that have been vectorized. They can be retrieved from `vectorizationMap`
// but it is convenient to keep track of them in a separate data structure.
DenseSet<OperationInst *> roots;
- // Terminator statements for the worklist in the vectorizeOperations function.
- // They consist of the subset of store operations that have been vectorized.
- // They can be retrieved from `vectorizationMap` but it is convenient to keep
- // track of them in a separate data structure. Since they do not necessarily
- // belong to use-def chains starting from loads (e.g storing a constant), we
- // need to handle them in a post-pass.
+ // Terminator instructions for the worklist in the vectorizeOperations
+ // function. They consist of the subset of store operations that have been
+ // vectorized. They can be retrieved from `vectorizationMap` but it is
+ // convenient to keep track of them in a separate data structure. Since they
+  // do not necessarily belong to use-def chains starting from loads (e.g.
+ // storing a constant), we need to handle them in a post-pass.
DenseSet<OperationInst *> terminators;
- // Checks that the type of `stmt` is StoreOp and adds it to the terminators
+ // Checks that the type of `inst` is StoreOp and adds it to the terminators
// set.
- void registerTerminator(OperationInst *stmt);
+ void registerTerminator(OperationInst *inst);
private:
void registerReplacement(const Value *key, Value *value);
@@ -784,19 +784,19 @@ void VectorizationState::registerReplacement(OperationInst *key,
}
}
-void VectorizationState::registerTerminator(OperationInst *stmt) {
- assert(stmt->isa<StoreOp>() && "terminator must be a StoreOp");
- assert(terminators.count(stmt) == 0 &&
+void VectorizationState::registerTerminator(OperationInst *inst) {
+ assert(inst->isa<StoreOp>() && "terminator must be a StoreOp");
+ assert(terminators.count(inst) == 0 &&
"terminator was already inserted previously");
- terminators.insert(stmt);
+ terminators.insert(inst);
}
void VectorizationState::finishVectorizationPattern() {
while (!toErase.empty()) {
- auto *stmt = toErase.pop_back_val();
+ auto *inst = toErase.pop_back_val();
LLVM_DEBUG(dbgs() << "\n[early-vect] finishVectorizationPattern erase: ");
- LLVM_DEBUG(stmt->print(dbgs()));
- stmt->erase();
+ LLVM_DEBUG(inst->print(dbgs()));
+ inst->erase();
}
}
@@ -832,23 +832,23 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType);
// Materialize a MemRef with 1 vector.
- auto *opStmt = memoryOp->getInstruction();
+ auto *opInst = memoryOp->getInstruction();
// For now, vector_transfers must be aligned, operate only on indices with an
// identity subset of AffineMap and do not change layout.
// TODO(ntv): increase the expressiveness power of vector_transfer operations
// as needed by various targets.
- if (opStmt->template isa<LoadOp>()) {
+ if (opInst->template isa<LoadOp>()) {
auto permutationMap =
- makePermutationMap(opStmt, state->strategy->loopToVectorDim);
+ makePermutationMap(opInst, state->strategy->loopToVectorDim);
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
LLVM_DEBUG(permutationMap.print(dbgs()));
- FuncBuilder b(opStmt);
+ FuncBuilder b(opInst);
auto transfer = b.create<VectorTransferReadOp>(
- opStmt->getLoc(), vectorType, memoryOp->getMemRef(),
+ opInst->getLoc(), vectorType, memoryOp->getMemRef(),
map(makePtrDynCaster<Value>(), memoryOp->getIndices()), permutationMap);
- state->registerReplacement(opStmt, transfer->getInstruction());
+ state->registerReplacement(opInst, transfer->getInstruction());
} else {
- state->registerTerminator(opStmt);
+ state->registerTerminator(opInst);
}
return false;
}
@@ -856,28 +856,29 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
/// Coarsens the loops bounds and transforms all remaining load and store
/// operations into the appropriate vector_transfer.
-static bool vectorizeForStmt(ForStmt *loop, int64_t step,
+static bool vectorizeForInst(ForInst *loop, int64_t step,
VectorizationState *state) {
using namespace functional;
loop->setStep(step);
- FilterFunctionType notVectorizedThisPattern = [state](const Statement &stmt) {
- if (!matcher::isLoadOrStore(stmt)) {
- return false;
- }
- auto *opStmt = cast<OperationInst>(&stmt);
- return state->vectorizationMap.count(opStmt) == 0 &&
- state->vectorizedSet.count(opStmt) == 0 &&
- state->roots.count(opStmt) == 0 &&
- state->terminators.count(opStmt) == 0;
- };
+ FilterFunctionType notVectorizedThisPattern =
+ [state](const Instruction &inst) {
+ if (!matcher::isLoadOrStore(inst)) {
+ return false;
+ }
+ auto *opInst = cast<OperationInst>(&inst);
+ return state->vectorizationMap.count(opInst) == 0 &&
+ state->vectorizedSet.count(opInst) == 0 &&
+ state->roots.count(opInst) == 0 &&
+ state->terminators.count(opInst) == 0;
+ };
auto loadAndStores = matcher::Op(notVectorizedThisPattern);
auto matches = loadAndStores.match(loop);
for (auto ls : matches) {
- auto *opStmt = cast<OperationInst>(ls.first);
- auto load = opStmt->dyn_cast<LoadOp>();
- auto store = opStmt->dyn_cast<StoreOp>();
- LLVM_DEBUG(opStmt->print(dbgs()));
+ auto *opInst = cast<OperationInst>(ls.first);
+ auto load = opInst->dyn_cast<LoadOp>();
+ auto store = opInst->dyn_cast<StoreOp>();
+ LLVM_DEBUG(opInst->print(dbgs()));
auto fail = load ? vectorizeRootOrTerminal(loop, load, state)
: vectorizeRootOrTerminal(loop, store, state);
if (fail) {
@@ -895,8 +896,8 @@ static bool vectorizeForStmt(ForStmt *loop, int64_t step,
/// we can build a cost model and a search procedure.
static FilterFunctionType
isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) {
- return [fastestVaryingMemRefDimension](const Statement &forStmt) {
- const auto &loop = cast<ForStmt>(forStmt);
+ return [fastestVaryingMemRefDimension](const Instruction &forInst) {
+ const auto &loop = cast<ForInst>(forInst);
return isVectorizableLoopAlongFastestVaryingMemRefDim(
loop, fastestVaryingMemRefDimension);
};
@@ -911,7 +912,7 @@ static bool vectorizeNonRoot(MLFunctionMatches matches,
/// recursively in DFS post-order.
static bool doVectorize(MLFunctionMatches::EntryType oneMatch,
VectorizationState *state) {
- ForStmt *loop = cast<ForStmt>(oneMatch.first);
+ ForInst *loop = cast<ForInst>(oneMatch.first);
MLFunctionMatches childrenMatches = oneMatch.second;
// 1. DFS postorder recursion, if any of my children fails, I fail too.
@@ -938,10 +939,10 @@ static bool doVectorize(MLFunctionMatches::EntryType oneMatch,
// exploratory tradeoffs (see top of the file). Apply coarsening, i.e.:
// | ub -> ub
// | step -> step * vectorSize
- LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForStmt by " << vectorSize
+ LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForInst by " << vectorSize
<< " : ");
LLVM_DEBUG(loop->print(dbgs()));
- return vectorizeForStmt(loop, loop->getStep() * vectorSize, state);
+ return vectorizeForInst(loop, loop->getStep() * vectorSize, state);
}
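Concretely, the coarsening in step 2: for a loop with step 1 and vector size 8, vectorizeForInst rewrites the step to 8 while leaving the bounds alone, so each remaining iteration stands in for 8 scalar iterations whose loads and stores become single vector_transfer operations.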
/// Non-root pattern iterates over the matches at this level, calls doVectorize
@@ -963,20 +964,20 @@ static bool vectorizeNonRoot(MLFunctionMatches matches,
/// element type.
/// If `type` is not a valid vector type or if the scalar constant is not a
/// valid vector element type, returns nullptr.
-static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
+static Value *vectorizeConstant(Instruction *inst, const ConstantOp &constant,
Type type) {
if (!type || !type.isa<VectorType>() ||
!VectorType::isValidElementType(constant.getType())) {
return nullptr;
}
- FuncBuilder b(stmt);
- Location loc = stmt->getLoc();
+ FuncBuilder b(inst);
+ Location loc = inst->getLoc();
auto vectorType = type.cast<VectorType>();
auto attr = SplatElementsAttr::get(vectorType, constant.getValue());
- auto *constantOpStmt = cast<OperationInst>(constant.getInstruction());
+ auto *constantOpInst = cast<OperationInst>(constant.getInstruction());
OperationState state(
- b.getContext(), loc, constantOpStmt->getName().getStringRef(), {},
+ b.getContext(), loc, constantOpInst->getName().getStringRef(), {},
{vectorType},
{make_pair(Identifier::get("value", b.getContext()), attr)});
@@ -985,7 +986,7 @@ static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
}
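The splat built by vectorizeConstant replicates one scalar across every lane of the vector type; SplatElementsAttr records this symbolically in the IR rather than materializing the lanes. A scalar-level sketch of the semantics, not of the patch's API:

#include <cstddef>
#include <vector>

// Every lane of the result holds the same scalar.
std::vector<float> splat(float scalar, std::size_t lanes) {
  return std::vector<float>(lanes, scalar);
}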
/// Returns a uniqu'ed VectorType.
-/// In the case `v`'s defining statement is already part of the `state`'s
+/// In the case `v`'s defining instruction is already part of the `state`'s
/// vectorizedSet, just returns the type of `v`.
/// Otherwise, constructs a new VectorType of shape defined by `state.strategy`
/// and of elemental type the type of `v`.
@@ -993,17 +994,17 @@ static Type getVectorType(Value *v, const VectorizationState &state) {
if (!VectorType::isValidElementType(v->getType())) {
return Type();
}
- auto *definingOpStmt = cast<OperationInst>(v->getDefiningInst());
- if (state.vectorizedSet.count(definingOpStmt) > 0) {
+ auto *definingOpInst = cast<OperationInst>(v->getDefiningInst());
+ if (state.vectorizedSet.count(definingOpInst) > 0) {
return v->getType().cast<VectorType>();
}
return VectorType::get(state.strategy->vectorSizes, v->getType());
};
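getVectorType's two cases reduce to: reuse the value's type if its def was already vectorized, otherwise wrap the scalar element type in the strategy's vector shape. A sketch with stand-in types:

#include <cstdint>
#include <vector>

struct Ty {
  bool isVector = false;
  std::vector<int64_t> shape; // empty for scalars
};

Ty getVectorTypeSketch(const Ty &t, bool defAlreadyVectorized,
                       const std::vector<int64_t> &strategySizes) {
  if (defAlreadyVectorized)
    return t;                       // type is already the vector type
  return Ty{true, strategySizes};   // wrap in the strategy's vector shape
}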
-/// Tries to vectorize a given operand `op` of Statement `stmt` during def-chain
-/// propagation or during terminator vectorization, by applying the following
-/// logic:
-/// 1. if the defining statement is part of the vectorizedSet (i.e. vectorized
+/// Tries to vectorize a given operand `op` of Instruction `inst` during
+/// def-chain propagation or during terminator vectorization, by applying the
+/// following logic:
+/// 1. if the defining instruction is part of the vectorizedSet (i.e. vectorized
/// use-def propagation), `op` is already in the proper vector form;
/// 2. otherwise, the `op` may be in some other vector form that fails to
/// vectorize atm (i.e. broadcasting required), returns nullptr to indicate
@@ -1021,13 +1022,13 @@ static Type getVectorType(Value *v, const VectorizationState &state) {
/// vectorization is possible with the above logic. Returns nullptr otherwise.
///
/// TODO(ntv): handle more complex cases.
-static Value *vectorizeOperand(Value *operand, Statement *stmt,
+static Value *vectorizeOperand(Value *operand, Instruction *inst,
VectorizationState *state) {
LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: ");
LLVM_DEBUG(operand->print(dbgs()));
- auto *definingStatement = cast<OperationInst>(operand->getDefiningInst());
+ auto *definingInstruction = cast<OperationInst>(operand->getDefiningInst());
// 1. If this value has already been vectorized this round, we are done.
- if (state->vectorizedSet.count(definingStatement) > 0) {
+ if (state->vectorizedSet.count(definingInstruction) > 0) {
LLVM_DEBUG(dbgs() << " -> already vector operand");
return operand;
}
@@ -1049,7 +1050,7 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt,
}
// 3. vectorize constant.
if (auto constant = operand->getDefiningInst()->dyn_cast<ConstantOp>()) {
- return vectorizeConstant(stmt, *constant,
+ return vectorizeConstant(inst, *constant,
getVectorType(operand, *state).cast<VectorType>());
}
// 4. currently non-vectorizable.
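The four numbered cases above form a small decision ladder. A self-contained sketch, where the types and the pre-splatted value are hypothetical stand-ins:

#include <map>
#include <set>

struct Value { bool isConstant = false; };

struct State {
  std::set<const Value *> vectorizedSet;            // defs already vectorized
  std::map<const Value *, Value *> replacementMap;  // scalar -> vector value
};

// nullptr means "give up on this use for now", e.g. a broadcast is required.
Value *vectorizeOperandSketch(Value *v, State &s, Value *splatted) {
  if (s.vectorizedSet.count(v))
    return v;                 // 1. already in vector form
  if (auto it = s.replacementMap.find(v); it != s.replacementMap.end())
    return it->second;        // 2. a vectorized replacement is on record
  if (v->isConstant)
    return splatted;          // 3. splat the constant (see vectorizeConstant)
  return nullptr;             // 4. not vectorizable at this point
}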
@@ -1068,41 +1069,41 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt,
/// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
/// do one-off logic here; ideally it would be TableGen'd.
static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
- OperationInst *opStmt,
+ OperationInst *opInst,
VectorizationState *state) {
// Sanity checks.
- assert(!opStmt->isa<LoadOp>() &&
+ assert(!opInst->isa<LoadOp>() &&
"all loads must have already been fully vectorized independently");
- assert(!opStmt->isa<VectorTransferReadOp>() &&
+ assert(!opInst->isa<VectorTransferReadOp>() &&
"vector_transfer_read cannot be further vectorized");
- assert(!opStmt->isa<VectorTransferWriteOp>() &&
+ assert(!opInst->isa<VectorTransferWriteOp>() &&
"vector_transfer_write cannot be further vectorized");
- if (auto store = opStmt->dyn_cast<StoreOp>()) {
+ if (auto store = opInst->dyn_cast<StoreOp>()) {
auto *memRef = store->getMemRef();
auto *value = store->getValueToStore();
- auto *vectorValue = vectorizeOperand(value, opStmt, state);
+ auto *vectorValue = vectorizeOperand(value, opInst, state);
auto indices = map(makePtrDynCaster<Value>(), store->getIndices());
- FuncBuilder b(opStmt);
+ FuncBuilder b(opInst);
auto permutationMap =
- makePermutationMap(opStmt, state->strategy->loopToVectorDim);
+ makePermutationMap(opInst, state->strategy->loopToVectorDim);
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
LLVM_DEBUG(permutationMap.print(dbgs()));
auto transfer = b.create<VectorTransferWriteOp>(
- opStmt->getLoc(), vectorValue, memRef, indices, permutationMap);
+ opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
auto *res = cast<OperationInst>(transfer->getInstruction());
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
// "Terminators" (i.e. StoreOps) are erased on the spot.
- opStmt->erase();
+ opInst->erase();
return res;
}
auto types = map([state](Value *v) { return getVectorType(v, *state); },
- opStmt->getResults());
- auto vectorizeOneOperand = [opStmt, state](Value *op) -> Value * {
- return vectorizeOperand(op, opStmt, state);
+ opInst->getResults());
+ auto vectorizeOneOperand = [opInst, state](Value *op) -> Value * {
+ return vectorizeOperand(op, opInst, state);
};
- auto operands = map(vectorizeOneOperand, opStmt->getOperands());
+ auto operands = map(vectorizeOneOperand, opInst->getOperands());
// Check whether a single operand is null. If so, vectorization failed.
bool success = llvm::all_of(operands, [](Value *op) { return op; });
if (!success) {
@@ -1116,9 +1117,9 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
// TODO(ntv): Is it worth considering an OperationInst.clone operation
// which changes the type so we can promote an OperationInst with less
// boilerplate?
- OperationState newOp(b->getContext(), opStmt->getLoc(),
- opStmt->getName().getStringRef(), operands, types,
- opStmt->getAttrs());
+ OperationState newOp(b->getContext(), opInst->getLoc(),
+ opInst->getName().getStringRef(), operands, types,
+ opInst->getAttrs());
return b->createOperation(newOp);
}
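Because result types cannot be changed on an existing op, vectorizeOneOperationInst rebuilds the op from its name, attributes, and the vectorized operand/result types (the OperationState idiom noted in the TODO). A string-level sketch of that rebuild, assuming a width of 8:

#include <string>
#include <vector>

struct OpDesc {
  std::string name;                      // op name and attrs carry over verbatim
  std::vector<std::string> operandTypes;
  std::vector<std::string> resultTypes;
};

// Assemble a vector-typed clone; erasing or replacing the original is the
// caller's job, as with the StoreOp path above.
OpDesc promoteToVector(const OpDesc &op) {
  OpDesc clone = op;
  for (auto &t : clone.operandTypes) t = "vector<8x" + t + ">";
  for (auto &t : clone.resultTypes)  t = "vector<8x" + t + ">";
  return clone;
}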
@@ -1137,13 +1138,13 @@ static bool vectorizeOperations(VectorizationState *state) {
auto insertUsesOf = [&worklist, state](OperationInst *vectorized) {
for (auto *r : vectorized->getResults())
for (auto &u : r->getUses()) {
- auto *stmt = cast<OperationInst>(u.getOwner());
+ auto *inst = cast<OperationInst>(u.getOwner());
// Don't propagate to terminals, a separate pass is needed for those.
// TODO(ntv)[b/119759136]: use isa<> once Op is implemented.
- if (state->terminators.count(stmt) > 0) {
+ if (state->terminators.count(inst) > 0) {
continue;
}
- worklist.insert(stmt);
+ worklist.insert(inst);
}
};
apply(insertUsesOf, state->roots);
@@ -1152,15 +1153,15 @@ static bool vectorizeOperations(VectorizationState *state) {
// size again. By construction, the order of elements in the worklist is
// consistent across iterations.
for (unsigned i = 0; i < worklist.size(); ++i) {
- auto *stmt = worklist[i];
+ auto *inst = worklist[i];
LLVM_DEBUG(dbgs() << "\n[early-vect] vectorize use: ");
- LLVM_DEBUG(stmt->print(dbgs()));
+ LLVM_DEBUG(inst->print(dbgs()));
- // 2. Create vectorized form of the statement.
- // Insert it just before stmt, on success register stmt as replaced.
- FuncBuilder b(stmt);
- auto *vectorizedStmt = vectorizeOneOperationInst(&b, stmt, state);
- if (!vectorizedStmt) {
+ // 2. Create vectorized form of the instruction.
+ // Insert it just before inst, on success register inst as replaced.
+ FuncBuilder b(inst);
+ auto *vectorizedInst = vectorizeOneOperationInst(&b, inst, state);
+ if (!vectorizedInst) {
return true;
}
@@ -1168,11 +1169,11 @@ static bool vectorizeOperations(VectorizationState *state) {
// Note that we cannot just call replaceAllUsesWith because it may
// result in ops with mixed types, for ops whose operands have not all
// yet been vectorized. This would be invalid IR.
- state->registerReplacement(stmt, vectorizedStmt);
+ state->registerReplacement(inst, vectorizedInst);
- // 4. Augment the worklist with uses of the statement we just vectorized.
+ // 4. Augment the worklist with uses of the instruction we just vectorized.
// This preserves the proper order in the worklist.
- apply(insertUsesOf, ArrayRef<OperationInst *>{stmt});
+ apply(insertUsesOf, ArrayRef<OperationInst *>{inst});
}
return false;
}
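The worklist loop above indexes by position precisely because vectorizing one instruction appends its uses: iterators would be invalidated, and the deterministic order would be lost. A sketch of the idiom; the real pass uses a SetVector, which deduplicates much as the std::set does here:

#include <cstddef>
#include <set>
#include <vector>

void propagateSketch(std::vector<int> &worklist,
                     const std::vector<std::vector<int>> &usesOf) {
  std::set<int> seen(worklist.begin(), worklist.end());
  for (std::size_t i = 0; i < worklist.size(); ++i)   // size re-read each trip
    for (int user : usesOf[worklist[i]])
      if (seen.insert(user).second)
        worklist.push_back(user); // newly enqueued items are visited later
}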
@@ -1184,7 +1185,7 @@ static bool vectorizeOperations(VectorizationState *state) {
static bool vectorizeRootMatches(MLFunctionMatches matches,
VectorizationStrategy *strategy) {
for (auto m : matches) {
- auto *loop = cast<ForStmt>(m.first);
+ auto *loop = cast<ForInst>(m.first);
VectorizationState state;
state.strategy = strategy;
@@ -1201,7 +1202,7 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
}
FuncBuilder builder(loop); // builder to insert in place of loop
DenseMap<const Value *, Value *> nomap;
- ForStmt *clonedLoop = cast<ForStmt>(builder.clone(*loop, nomap));
+ ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop, nomap));
auto fail = doVectorize(m, &state);
/// Sets up error handling for this root loop. This is how the root match
/// maintains a clone for handling failure and restores the proper state via
@@ -1230,8 +1231,8 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
auto roots = map(getDefiningInst, map(getKey, state.replacementMap));
// Vectorize the root operations and everything reached by use-def chains
- // except the terminators (store statements) that need to be post-processed
- // separately.
+ // except the terminators (store instructions) that need to be
+ // post-processed separately.
fail = vectorizeOperations(&state);
if (fail) {
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeOperations");
@@ -1239,12 +1240,12 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
}
// Finally, vectorize the terminators. If anything fails to vectorize, skip.
- auto vectorizeOrFail = [&fail, &state](OperationInst *stmt) {
+ auto vectorizeOrFail = [&fail, &state](OperationInst *inst) {
if (fail) {
return;
}
- FuncBuilder b(stmt);
- auto *res = vectorizeOneOperationInst(&b, stmt, &state);
+ FuncBuilder b(inst);
+ auto *res = vectorizeOneOperationInst(&b, inst, &state);
if (res == nullptr) {
fail = true;
}
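The rollback protocol built around the cloned loop: copy the loop before vectorizing, and if anything downstream fails, reinstall the copy. A minimal sketch with a hypothetical Loop type:

#include <functional>
#include <memory>

struct Loop { /* bounds, step, body ... */ };

bool vectorizeWithRollback(std::unique_ptr<Loop> &loop,
                           const std::function<bool(Loop &)> &tryVectorize) {
  auto backup = std::make_unique<Loop>(*loop); // pre-vectorization clone
  if (tryVectorize(*loop))
    return true;               // success: the backup is simply discarded
  loop = std::move(backup);    // failure: restore the original form
  return false;
}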
@@ -1284,7 +1285,7 @@ PassResult Vectorize::runOnMLFunction(Function *f) {
if (fail) {
continue;
}
- auto *loop = cast<ForStmt>(m.first);
+ auto *loop = cast<ForInst>(m.first);
vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy);
// TODO(ntv): if pattern does not apply, report it; alter the
// cost/benefit.