summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
authorChris Lattner <clattner@google.com>2018-12-27 21:21:41 -0800
committerjpienaar <jpienaar@google.com>2019-03-29 14:42:23 -0700
commit5187cfcf03d36fcd9a08adb768d0bc584ef9e50d (patch)
treea78a2e7454c02452df8370b107a1c1ed336bad64 /mlir/lib
parent3b021d7f2e6bfd42593af76c02d2aa9c26beaaf0 (diff)
downloadbcm5719-llvm-5187cfcf03d36fcd9a08adb768d0bc584ef9e50d.tar.gz
bcm5719-llvm-5187cfcf03d36fcd9a08adb768d0bc584ef9e50d.zip
Merge Operation into OperationInst and standardize nomenclature around
OperationInst. This is a big mechanical patch. This is step 16/n towards merging instructions and statements, NFC. PiperOrigin-RevId: 227093712
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/AffineAnalysis.cpp14
-rw-r--r--mlir/lib/Analysis/AffineStructures.cpp2
-rw-r--r--mlir/lib/Analysis/LoopAnalysis.cpp14
-rw-r--r--mlir/lib/Analysis/MLFunctionMatcher.cpp4
-rw-r--r--mlir/lib/Analysis/MemRefBoundCheck.cpp4
-rw-r--r--mlir/lib/Analysis/MemRefDependenceCheck.cpp10
-rw-r--r--mlir/lib/Analysis/OpStats.cpp4
-rw-r--r--mlir/lib/Analysis/SliceAnalysis.cpp6
-rw-r--r--mlir/lib/Analysis/Utils.cpp6
-rw-r--r--mlir/lib/Analysis/VectorAnalysis.cpp6
-rw-r--r--mlir/lib/Analysis/Verifier.cpp24
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp40
-rw-r--r--mlir/lib/IR/Builders.cpp2
-rw-r--r--mlir/lib/IR/BuiltinOps.cpp24
-rw-r--r--mlir/lib/IR/Function.cpp24
-rw-r--r--mlir/lib/IR/MLIRContext.cpp2
-rw-r--r--mlir/lib/IR/Operation.cpp256
-rw-r--r--mlir/lib/IR/PatternMatch.cpp15
-rw-r--r--mlir/lib/IR/Statement.cpp166
-rw-r--r--mlir/lib/IR/StmtBlock.cpp4
-rw-r--r--mlir/lib/IR/Value.cpp24
-rw-r--r--mlir/lib/Parser/Parser.cpp18
-rw-r--r--mlir/lib/StandardOps/StandardOps.cpp24
-rw-r--r--mlir/lib/SuperVectorOps/SuperVectorOps.cpp8
-rw-r--r--mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp12
-rw-r--r--mlir/lib/Transforms/CSE.cpp18
-rw-r--r--mlir/lib/Transforms/ComposeAffineMaps.cpp6
-rw-r--r--mlir/lib/Transforms/ConstantFold.cpp19
-rw-r--r--mlir/lib/Transforms/ConvertToCFG.cpp12
-rw-r--r--mlir/lib/Transforms/DmaGeneration.cpp4
-rw-r--r--mlir/lib/Transforms/LoopFusion.cpp28
-rw-r--r--mlir/lib/Transforms/LoopUnroll.cpp2
-rw-r--r--mlir/lib/Transforms/LowerVectorTransfers.cpp9
-rw-r--r--mlir/lib/Transforms/MaterializeVectors.cpp35
-rw-r--r--mlir/lib/Transforms/PipelineDataTransfer.cpp20
-rw-r--r--mlir/lib/Transforms/SimplifyAffineExpr.cpp4
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp39
-rw-r--r--mlir/lib/Transforms/Utils/LoopUtils.cpp2
-rw-r--r--mlir/lib/Transforms/Utils/LoweringUtils.cpp2
-rw-r--r--mlir/lib/Transforms/Utils/Utils.cpp33
-rw-r--r--mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp10
-rw-r--r--mlir/lib/Transforms/Vectorize.cpp88
42 files changed, 474 insertions, 570 deletions
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index 78115b974a1..f3fde8bb95f 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -478,14 +478,14 @@ bool mlir::getFlattenedAffineExprs(
localVarCst);
}
-/// Returns the sequence of AffineApplyOp OperationStmts operation in
+/// Returns the sequence of AffineApplyOp OperationInsts operation in
/// 'affineApplyOps', which are reachable via a search starting from 'operands',
/// and ending at operands which are not defined by AffineApplyOps.
// TODO(andydavis) Add a method to AffineApplyOp which forward substitutes
// the AffineApplyOp into any user AffineApplyOps.
void mlir::getReachableAffineApplyOps(
ArrayRef<Value *> operands,
- SmallVectorImpl<OperationStmt *> &affineApplyOps) {
+ SmallVectorImpl<OperationInst *> &affineApplyOps) {
struct State {
// The ssa value for this node in the DFS traversal.
Value *value;
@@ -499,9 +499,9 @@ void mlir::getReachableAffineApplyOps(
while (!worklist.empty()) {
State &state = worklist.back();
- auto *opStmt = state.value->getDefiningStmt();
- // Note: getDefiningStmt will return nullptr if the operand is not an
- // OperationStmt (i.e. ForStmt), which is a terminator for the search.
+ auto *opStmt = state.value->getDefiningInst();
+ // Note: getDefiningInst will return nullptr if the operand is not an
+ // OperationInst (i.e. ForStmt), which is a terminator for the search.
if (opStmt == nullptr || !opStmt->isa<AffineApplyOp>()) {
worklist.pop_back();
continue;
@@ -531,7 +531,7 @@ void mlir::getReachableAffineApplyOps(
// operands of 'valueMap'.
void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) {
// Gather AffineApplyOps reachable from 'indices'.
- SmallVector<OperationStmt *, 4> affineApplyOps;
+ SmallVector<OperationInst *, 4> affineApplyOps;
getReachableAffineApplyOps(valueMap->getOperands(), affineApplyOps);
// Compose AffineApplyOps in 'affineApplyOps'.
for (auto *opStmt : affineApplyOps) {
@@ -842,7 +842,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
auto *symbol = operands[i];
assert(symbol->isValidSymbol());
// Check if the symbol is a constant.
- if (auto *opStmt = symbol->getDefiningStmt()) {
+ if (auto *opStmt = symbol->getDefiningInst()) {
if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol),
constOp->getValue());
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index bfdaceff7e7..dd564df3017 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1269,7 +1269,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
addSymbolId(getNumSymbolIds(), const_cast<Value *>(operand));
loc = getNumDimIds() + getNumSymbolIds() - 1;
// Check if the symbol is a constant.
- if (auto *opStmt = operand->getDefiningStmt()) {
+ if (auto *opStmt = operand->getDefiningInst()) {
if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
setIdToConstant(*operand, constOp->getValue());
}
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index 7213ba5986a..85af39222c4 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -127,7 +127,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
bool mlir::isAccessInvariant(const Value &iv, const Value &index) {
assert(isa<ForStmt>(iv) && "iv must be a ForStmt");
assert(index.getType().isa<IndexType>() && "index must be of IndexType");
- SmallVector<OperationStmt *, 4> affineApplyOps;
+ SmallVector<OperationInst *, 4> affineApplyOps;
getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps);
if (affineApplyOps.empty()) {
@@ -234,13 +234,13 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
}
static bool isVectorTransferReadOrWrite(const Statement &stmt) {
- const auto *opStmt = cast<OperationStmt>(&stmt);
+ const auto *opStmt = cast<OperationInst>(&stmt);
return opStmt->isa<VectorTransferReadOp>() ||
opStmt->isa<VectorTransferWriteOp>();
}
using VectorizableStmtFun =
- std::function<bool(const ForStmt &, const OperationStmt &)>;
+ std::function<bool(const ForStmt &, const OperationInst &)>;
static bool isVectorizableLoopWithCond(const ForStmt &loop,
VectorizableStmtFun isVectorizableStmt) {
@@ -265,7 +265,7 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop,
auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
auto loadAndStoresMatched = loadAndStores.match(forStmt);
for (auto ls : loadAndStoresMatched) {
- auto *op = cast<OperationStmt>(ls.first);
+ auto *op = cast<OperationInst>(ls.first);
auto load = op->dyn_cast<LoadOp>();
auto store = op->dyn_cast<StoreOp>();
// Only scalar types are considered vectorizable, all load/store must be
@@ -285,7 +285,7 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop,
bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
const ForStmt &loop, unsigned fastestVaryingDim) {
VectorizableStmtFun fun(
- [fastestVaryingDim](const ForStmt &loop, const OperationStmt &op) {
+ [fastestVaryingDim](const ForStmt &loop, const OperationInst &op) {
auto load = op.dyn_cast<LoadOp>();
auto store = op.dyn_cast<StoreOp>();
return load ? isContiguousAccess(loop, *load, fastestVaryingDim)
@@ -297,7 +297,7 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
bool mlir::isVectorizableLoop(const ForStmt &loop) {
VectorizableStmtFun fun(
// TODO: implement me
- [](const ForStmt &loop, const OperationStmt &op) { return true; });
+ [](const ForStmt &loop, const OperationInst &op) { return true; });
return isVectorizableLoopWithCond(loop, fun);
}
@@ -314,7 +314,7 @@ bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
for (const auto &stmt : *forBody) {
// A for or if stmt does not produce any def/results (that are used
// outside).
- if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
+ if (const auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
const Value *result = opStmt->getResult(i);
for (const StmtOperand &use : result->getUses()) {
diff --git a/mlir/lib/Analysis/MLFunctionMatcher.cpp b/mlir/lib/Analysis/MLFunctionMatcher.cpp
index c227aa3fcdd..c03fed5986b 100644
--- a/mlir/lib/Analysis/MLFunctionMatcher.cpp
+++ b/mlir/lib/Analysis/MLFunctionMatcher.cpp
@@ -200,7 +200,7 @@ namespace mlir {
namespace matcher {
MLFunctionMatcher Op(FilterFunctionType filter) {
- return MLFunctionMatcher(Statement::Kind::Operation, {}, filter);
+ return MLFunctionMatcher(Statement::Kind::OperationInst, {}, filter);
}
MLFunctionMatcher If(MLFunctionMatcher child) {
@@ -246,7 +246,7 @@ bool isReductionLoop(const Statement &stmt) {
};
bool isLoadOrStore(const Statement &stmt) {
- const auto *opStmt = dyn_cast<OperationStmt>(&stmt);
+ const auto *opStmt = dyn_cast<OperationInst>(&stmt);
return opStmt && (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>());
};
diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp
index 995bb466fef..1cb039fe00e 100644
--- a/mlir/lib/Analysis/MemRefBoundCheck.cpp
+++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp
@@ -45,7 +45,7 @@ struct MemRefBoundCheck : public FunctionPass, StmtWalker<MemRefBoundCheck> {
// Not applicable to CFG functions.
PassResult runOnCFGFunction(CFGFunction *f) override { return success(); }
- void visitOperationStmt(OperationStmt *opStmt);
+ void visitOperationInst(OperationInst *opStmt);
static char passID;
};
@@ -58,7 +58,7 @@ FunctionPass *mlir::createMemRefBoundCheckPass() {
return new MemRefBoundCheck();
}
-void MemRefBoundCheck::visitOperationStmt(OperationStmt *opStmt) {
+void MemRefBoundCheck::visitOperationInst(OperationInst *opStmt) {
if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
boundCheckLoadOrStoreOp(loadOp);
} else if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp
index 7c57a66310a..ec33c619a17 100644
--- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp
+++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp
@@ -40,7 +40,7 @@ namespace {
/// Checks dependences between all pairs of memref accesses in an MLFunction.
struct MemRefDependenceCheck : public FunctionPass,
StmtWalker<MemRefDependenceCheck> {
- SmallVector<OperationStmt *, 4> loadsAndStores;
+ SmallVector<OperationInst *, 4> loadsAndStores;
explicit MemRefDependenceCheck()
: FunctionPass(&MemRefDependenceCheck::passID) {}
@@ -48,7 +48,7 @@ struct MemRefDependenceCheck : public FunctionPass,
// Not applicable to CFG functions.
PassResult runOnCFGFunction(CFGFunction *f) override { return success(); }
- void visitOperationStmt(OperationStmt *opStmt) {
+ void visitOperationInst(OperationInst *opStmt) {
if (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>()) {
loadsAndStores.push_back(opStmt);
}
@@ -66,7 +66,7 @@ FunctionPass *mlir::createMemRefDependenceCheckPass() {
// Adds memref access indices 'opIndices' from 'memrefType' to 'access'.
static void addMemRefAccessIndices(
- llvm::iterator_range<Operation::const_operand_iterator> opIndices,
+ llvm::iterator_range<OperationInst::const_operand_iterator> opIndices,
MemRefType memrefType, MemRefAccess *access) {
access->indices.reserve(memrefType.getRank());
for (auto *index : opIndices) {
@@ -75,7 +75,7 @@ static void addMemRefAccessIndices(
}
// Populates 'access' with memref, indices and opstmt from 'loadOrStoreOpStmt'.
-static void getMemRefAccess(const OperationStmt *loadOrStoreOpStmt,
+static void getMemRefAccess(const OperationInst *loadOrStoreOpStmt,
MemRefAccess *access) {
access->opStmt = loadOrStoreOpStmt;
if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
@@ -131,7 +131,7 @@ getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth,
// "source" access and all subsequent "destination" accesses in
// 'loadsAndStores'. Emits the result of the dependence check as a note with
// the source access.
-static void checkDependences(ArrayRef<OperationStmt *> loadsAndStores) {
+static void checkDependences(ArrayRef<OperationInst *> loadsAndStores) {
for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) {
auto *srcOpStmt = loadsAndStores[i];
MemRefAccess srcAccess;
diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp
index d9a0edd6d83..cea0c087297 100644
--- a/mlir/lib/Analysis/OpStats.cpp
+++ b/mlir/lib/Analysis/OpStats.cpp
@@ -38,7 +38,7 @@ struct PrintOpStatsPass : public FunctionPass, StmtWalker<PrintOpStatsPass> {
// Process ML functions and operation statments in ML functions.
PassResult runOnMLFunction(MLFunction *function) override;
- void visitOperationStmt(OperationStmt *stmt);
+ void visitOperationInst(OperationInst *stmt);
// Print summary of op stats.
void printSummary();
@@ -69,7 +69,7 @@ PassResult PrintOpStatsPass::runOnCFGFunction(CFGFunction *function) {
return success();
}
-void PrintOpStatsPass::visitOperationStmt(OperationStmt *stmt) {
+void PrintOpStatsPass::visitOperationInst(OperationInst *stmt) {
++opCount[stmt->getName().getStringRef()];
}
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index b7873f8327f..c06bf4df61e 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -52,7 +52,7 @@ void mlir::getForwardSlice(Statement *stmt,
return;
}
- if (auto *opStmt = dyn_cast<OperationStmt>(stmt)) {
+ if (auto *opStmt = dyn_cast<OperationInst>(stmt)) {
assert(opStmt->getNumResults() <= 1 && "NYI: multiple results");
if (opStmt->getNumResults() > 0) {
for (auto &u : opStmt->getResult(0)->getUses()) {
@@ -102,7 +102,7 @@ void mlir::getBackwardSlice(Statement *stmt,
}
for (auto *operand : stmt->getOperands()) {
- auto *stmt = operand->getDefiningStmt();
+ auto *stmt = operand->getDefiningInst();
if (backwardSlice->count(stmt) == 0) {
getBackwardSlice(stmt, backwardSlice, filter,
/*topLevel=*/false);
@@ -156,7 +156,7 @@ struct DFSState {
} // namespace
static void DFSPostorder(Statement *current, DFSState *state) {
- auto *opStmt = cast<OperationStmt>(current);
+ auto *opStmt = cast<OperationInst>(current);
assert(opStmt->getNumResults() <= 1 && "NYI: multi-result");
if (opStmt->getNumResults() > 0) {
for (auto &u : opStmt->getResult(0)->getUses()) {
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index e6975ac5d09..a63723b333c 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -145,7 +145,7 @@ Optional<int64_t> MemRefRegion::getBoundingConstantSizeAndShape(
//
// TODO(bondhugula): extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
-bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
+bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
MemRefRegion *region) {
OpPointer<LoadOp> loadOp;
OpPointer<StoreOp> storeOp;
@@ -204,7 +204,7 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
auto *symbol = accessValueMap.getOperand(i);
assert(symbol->isValidSymbol());
// Check if the symbol is a constant.
- if (auto *opStmt = symbol->getDefiningStmt()) {
+ if (auto *opStmt = symbol->getDefiningInst()) {
if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
regionCst->setIdToConstant(*symbol, constOp->getValue());
}
@@ -282,7 +282,7 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
std::is_same<LoadOrStoreOpPointer, OpPointer<StoreOp>>::value,
"function argument should be either a LoadOp or a StoreOp");
- OperationStmt *opStmt = cast<OperationStmt>(loadOrStoreOp->getOperation());
+ OperationInst *opStmt = cast<OperationInst>(loadOrStoreOp->getOperation());
MemRefRegion region;
if (!getMemRefRegion(opStmt, /*loopDepth=*/0, &region))
return false;
diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp
index ec19194f2fa..cd9451cd5e9 100644
--- a/mlir/lib/Analysis/VectorAnalysis.cpp
+++ b/mlir/lib/Analysis/VectorAnalysis.cpp
@@ -104,7 +104,7 @@ Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(VectorType superVectorType,
/// header file.
static AffineMap makePermutationMap(
MLIRContext *context,
- llvm::iterator_range<Operation::operand_iterator> indices,
+ llvm::iterator_range<OperationInst::operand_iterator> indices,
const DenseMap<ForStmt *, unsigned> &enclosingLoopToVectorDim) {
using functional::makePtrDynCaster;
using functional::map;
@@ -157,7 +157,7 @@ static SetVector<ForStmt *> getEnclosingForStmts(Statement *stmt) {
}
AffineMap
-mlir::makePermutationMap(OperationStmt *opStmt,
+mlir::makePermutationMap(OperationInst *opStmt,
const DenseMap<ForStmt *, unsigned> &loopToVectorDim) {
DenseMap<ForStmt *, unsigned> enclosingLoopToVectorDim;
auto enclosingLoops = getEnclosingForStmts(opStmt);
@@ -178,7 +178,7 @@ mlir::makePermutationMap(OperationStmt *opStmt,
enclosingLoopToVectorDim);
}
-bool mlir::matcher::operatesOnStrictSuperVectors(const OperationStmt &opStmt,
+bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt,
VectorType subVectorType) {
// First, extract the vector type and ditinguish between:
// a. ops that *must* lower a super-vector (i.e. vector_transfer_read,
diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp
index e7abb899a11..e1de6191de6 100644
--- a/mlir/lib/Analysis/Verifier.cpp
+++ b/mlir/lib/Analysis/Verifier.cpp
@@ -51,7 +51,7 @@ namespace {
///
class Verifier {
public:
- bool failure(const Twine &message, const Operation &value) {
+ bool failure(const Twine &message, const OperationInst &value) {
return value.emitError(message);
}
@@ -62,15 +62,15 @@ public:
bool failure(const Twine &message, const BasicBlock &bb) {
// Take the location information for the first instruction in the block.
if (!bb.empty())
- if (auto *op = dyn_cast<OperationStmt>(&bb.front()))
+ if (auto *op = dyn_cast<OperationInst>(&bb.front()))
return failure(message, *op);
// Worst case, fall back to using the function's location.
return failure(message, fn);
}
- bool verifyOperation(const Operation &op);
- bool verifyAttribute(Attribute attr, const Operation &op);
+ bool verifyOperation(const OperationInst &op);
+ bool verifyAttribute(Attribute attr, const OperationInst &op);
protected:
explicit Verifier(const Function &fn) : fn(fn) {}
@@ -82,7 +82,7 @@ private:
} // end anonymous namespace
// Check that function attributes are all well formed.
-bool Verifier::verifyAttribute(Attribute attr, const Operation &op) {
+bool Verifier::verifyAttribute(Attribute attr, const OperationInst &op) {
if (!attr.isOrContainsFunction())
return false;
@@ -109,9 +109,9 @@ bool Verifier::verifyAttribute(Attribute attr, const Operation &op) {
return false;
}
-/// Check the invariants of the specified operation instruction or statement.
-bool Verifier::verifyOperation(const Operation &op) {
- if (op.getOperationFunction() != &fn)
+/// Check the invariants of the specified operation.
+bool Verifier::verifyOperation(const OperationInst &op) {
+ if (op.getFunction() != &fn)
return failure("operation in the wrong function", op);
// Check that operands are non-nil and structurally ok.
@@ -245,7 +245,7 @@ struct MLFuncVerifier : public Verifier, public StmtWalker<MLFuncVerifier> {
MLFuncVerifier(const MLFunction &fn) : Verifier(fn), fn(fn) {}
- void visitOperationStmt(OperationStmt *opStmt) {
+ void visitOperationInst(OperationInst *opStmt) {
hadError |= verifyOperation(*opStmt);
}
@@ -302,14 +302,14 @@ bool MLFuncVerifier::verifyDominance() {
if (!liveValues.count(opValue)) {
stmt.emitError("operand #" + Twine(operandNo) +
" does not dominate this use");
- if (auto *useStmt = opValue->getDefiningStmt())
+ if (auto *useStmt = opValue->getDefiningInst())
useStmt->emitNote("operand defined here");
return true;
}
++operandNo;
}
- if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
+ if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
// Operations define values, add them to the hash table.
for (auto *result : opStmt->getResults())
liveValues.insert(result, true);
@@ -344,7 +344,7 @@ bool MLFuncVerifier::verifyReturn() {
return failure(missingReturnMsg, fn);
const auto &stmt = fn.getBody()->getStatements().back();
- if (const auto *op = dyn_cast<OperationStmt>(&stmt)) {
+ if (const auto *op = dyn_cast<OperationInst>(&stmt)) {
if (!op->isReturn())
return failure(missingReturnMsg, fn);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index c44ce4d4d6c..9f465ab8507 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -120,10 +120,10 @@ private:
void visitStatement(const Statement *stmt);
void visitForStmt(const ForStmt *forStmt);
void visitIfStmt(const IfStmt *ifStmt);
- void visitOperationStmt(const OperationStmt *opStmt);
+ void visitOperationInst(const OperationInst *opStmt);
void visitType(Type type);
void visitAttribute(Attribute attr);
- void visitOperation(const Operation *op);
+ void visitOperation(const OperationInst *op);
DenseMap<AffineMap, int> affineMapIds;
std::vector<AffineMap> affineMapsById;
@@ -161,7 +161,7 @@ void ModuleState::visitAttribute(Attribute attr) {
}
}
-void ModuleState::visitOperation(const Operation *op) {
+void ModuleState::visitOperation(const OperationInst *op) {
// Visit all the types used in the operation.
for (auto *operand : op->getOperands())
visitType(operand->getType());
@@ -212,7 +212,7 @@ void ModuleState::visitForStmt(const ForStmt *forStmt) {
visitStatement(&childStmt);
}
-void ModuleState::visitOperationStmt(const OperationStmt *opStmt) {
+void ModuleState::visitOperationInst(const OperationInst *opStmt) {
for (auto attr : opStmt->getAttrs())
visitAttribute(attr.second);
}
@@ -223,8 +223,8 @@ void ModuleState::visitStatement(const Statement *stmt) {
return visitIfStmt(cast<IfStmt>(stmt));
case Statement::Kind::For:
return visitForStmt(cast<ForStmt>(stmt));
- case Statement::Kind::Operation:
- return visitOperationStmt(cast<OperationStmt>(stmt));
+ case Statement::Kind::OperationInst:
+ return visitOperationInst(cast<OperationInst>(stmt));
default:
return;
}
@@ -944,8 +944,8 @@ class FunctionPrinter : public ModulePrinter, private OpAsmPrinter {
public:
FunctionPrinter(const ModulePrinter &other) : ModulePrinter(other) {}
- void printOperation(const Operation *op);
- void printDefaultOp(const Operation *op);
+ void printOperation(const OperationInst *op);
+ void printDefaultOp(const OperationInst *op);
// Implement OpAsmPrinter.
raw_ostream &getStream() const { return os; }
@@ -983,7 +983,7 @@ protected:
llvm::raw_svector_ostream specialName(specialNameBuffer);
// Give constant integers special names.
- if (auto *op = value->getDefiningOperation()) {
+ if (auto *op = value->getDefiningInst()) {
if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
// i1 constants get special names.
if (intOp->getType().isInteger(1)) {
@@ -1111,7 +1111,7 @@ private:
};
} // end anonymous namespace
-void FunctionPrinter::printOperation(const Operation *op) {
+void FunctionPrinter::printOperation(const OperationInst *op) {
if (op->getNumResults()) {
printValueID(op->getResult(0), /*printResultNo=*/false);
os << " = ";
@@ -1128,7 +1128,7 @@ void FunctionPrinter::printOperation(const Operation *op) {
printDefaultOp(op);
}
-void FunctionPrinter::printDefaultOp(const Operation *op) {
+void FunctionPrinter::printDefaultOp(const OperationInst *op) {
os << '"';
printEscapedString(op->getName().getStringRef(), os);
os << "\"(";
@@ -1172,7 +1172,7 @@ public:
void print(const Instruction *inst);
- void printSuccessorAndUseList(const Operation *term, unsigned index);
+ void printSuccessorAndUseList(const OperationInst *term, unsigned index);
void printBBName(const BasicBlock *block) { os << "bb" << getBBID(block); }
@@ -1302,7 +1302,7 @@ void CFGFunctionPrinter::printBranchOperands(const Range &range) {
os << ')';
}
-void CFGFunctionPrinter::printSuccessorAndUseList(const Operation *term,
+void CFGFunctionPrinter::printSuccessorAndUseList(const OperationInst *term,
unsigned index) {
printBBName(term->getSuccessor(index));
printBranchOperands(term->getSuccessorOperands(index));
@@ -1331,11 +1331,11 @@ public:
// Methods to print ML function statements.
void print(const Statement *stmt);
- void print(const OperationStmt *stmt);
+ void print(const OperationInst *stmt);
void print(const ForStmt *stmt);
void print(const IfStmt *stmt);
void print(const StmtBlock *block);
- void printSuccessorAndUseList(const Operation *term, unsigned index) {
+ void printSuccessorAndUseList(const OperationInst *term, unsigned index) {
assert(false && "MLFunctions do not have terminators with successors.");
}
@@ -1371,7 +1371,7 @@ void MLFunctionPrinter::numberValues() {
// the first result of the operation statements.
struct NumberValuesPass : public StmtWalker<NumberValuesPass> {
NumberValuesPass(MLFunctionPrinter *printer) : printer(printer) {}
- void visitOperationStmt(OperationStmt *stmt) {
+ void visitOperationInst(OperationInst *stmt) {
if (stmt->getNumResults() != 0)
printer->numberValueID(stmt->getResult(0));
}
@@ -1421,8 +1421,8 @@ void MLFunctionPrinter::print(const StmtBlock *block) {
void MLFunctionPrinter::print(const Statement *stmt) {
switch (stmt->getKind()) {
- case Statement::Kind::Operation:
- return print(cast<OperationStmt>(stmt));
+ case Statement::Kind::OperationInst:
+ return print(cast<OperationInst>(stmt));
case Statement::Kind::For:
return print(cast<ForStmt>(stmt));
case Statement::Kind::If:
@@ -1430,7 +1430,7 @@ void MLFunctionPrinter::print(const Statement *stmt) {
}
}
-void MLFunctionPrinter::print(const OperationStmt *stmt) {
+void MLFunctionPrinter::print(const OperationInst *stmt) {
os.indent(numSpaces);
printOperation(stmt);
}
@@ -1580,7 +1580,7 @@ void Value::print(raw_ostream &os) const {
os << "<block argument>\n";
return;
case Value::Kind::StmtResult:
- return getDefiningStmt()->print(os);
+ return getDefiningInst()->print(os);
case Value::Kind::ForStmt:
return cast<ForStmt>(this)->print(os);
}
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 0d3e54364b3..81a3b7c2950 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -290,7 +290,7 @@ StmtBlock *FuncBuilder::createBlock(StmtBlock *insertBefore) {
}
/// Create an operation given the fields represented as an OperationState.
-OperationStmt *FuncBuilder::createOperation(const OperationState &state) {
+OperationInst *FuncBuilder::createOperation(const OperationState &state) {
auto *op = OperationInst::create(state.location, state.name, state.operands,
state.types, state.attributes,
state.successors, context);
diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp
index 50ab254dd76..a87ae6b85f0 100644
--- a/mlir/lib/IR/BuiltinOps.cpp
+++ b/mlir/lib/IR/BuiltinOps.cpp
@@ -36,8 +36,8 @@ BuiltinDialect::BuiltinDialect(MLIRContext *context)
addOperations<AffineApplyOp, BranchOp, CondBranchOp, ConstantOp, ReturnOp>();
}
-void mlir::printDimAndSymbolList(Operation::const_operand_iterator begin,
- Operation::const_operand_iterator end,
+void mlir::printDimAndSymbolList(OperationInst::const_operand_iterator begin,
+ OperationInst::const_operand_iterator end,
unsigned numDims, OpAsmPrinter *p) {
*p << '(';
p->printOperands(begin, begin + numDims);
@@ -188,14 +188,12 @@ void BranchOp::print(OpAsmPrinter *p) const {
bool BranchOp::verify() const {
// ML functions do not have branching terminators.
- if (getOperation()->getOperationFunction()->isML())
+ if (getOperation()->getFunction()->isML())
return (emitOpError("cannot occur in a ML function"), true);
return false;
}
-BasicBlock *BranchOp::getDest() const {
- return getOperation()->getSuccessor(0);
-}
+BasicBlock *BranchOp::getDest() { return getOperation()->getSuccessor(0); }
void BranchOp::setDest(BasicBlock *block) {
return getOperation()->setSuccessor(block, 0);
@@ -258,18 +256,18 @@ void CondBranchOp::print(OpAsmPrinter *p) const {
bool CondBranchOp::verify() const {
// ML functions do not have branching terminators.
- if (getOperation()->getOperationFunction()->isML())
+ if (getOperation()->getFunction()->isML())
return (emitOpError("cannot occur in a ML function"), true);
if (!getCondition()->getType().isInteger(1))
return emitOpError("expected condition type was boolean (i1)");
return false;
}
-BasicBlock *CondBranchOp::getTrueDest() const {
+BasicBlock *CondBranchOp::getTrueDest() {
return getOperation()->getSuccessor(trueIndex);
}
-BasicBlock *CondBranchOp::getFalseDest() const {
+BasicBlock *CondBranchOp::getFalseDest() {
return getOperation()->getSuccessor(falseIndex);
}
@@ -399,13 +397,13 @@ void ConstantFloatOp::build(Builder *builder, OperationState *result,
ConstantOp::build(builder, result, builder->getFloatAttr(type, value), type);
}
-bool ConstantFloatOp::isClassFor(const Operation *op) {
+bool ConstantFloatOp::isClassFor(const OperationInst *op) {
return ConstantOp::isClassFor(op) &&
op->getResult(0)->getType().isa<FloatType>();
}
/// ConstantIntOp only matches values whose result type is an IntegerType.
-bool ConstantIntOp::isClassFor(const Operation *op) {
+bool ConstantIntOp::isClassFor(const OperationInst *op) {
return ConstantOp::isClassFor(op) &&
op->getResult(0)->getType().isa<IntegerType>();
}
@@ -427,7 +425,7 @@ void ConstantIntOp::build(Builder *builder, OperationState *result,
}
/// ConstantIndexOp only matches values whose result type is Index.
-bool ConstantIndexOp::isClassFor(const Operation *op) {
+bool ConstantIndexOp::isClassFor(const OperationInst *op) {
return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex();
}
@@ -470,7 +468,7 @@ void ReturnOp::print(OpAsmPrinter *p) const {
}
bool ReturnOp::verify() const {
- auto *function = cast<OperationStmt>(getOperation())->getFunction();
+ auto *function = cast<OperationInst>(getOperation())->getFunction();
// The operand number and types must match the function signature.
const auto &results = function->getType().getResults();
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
index 62f1dca067d..19b137071f4 100644
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -161,34 +161,34 @@ bool Function::emitError(const Twine &message) const {
// MLFunction implementation.
//===----------------------------------------------------------------------===//
-const OperationStmt *MLFunction::getReturnStmt() const {
- return cast<OperationStmt>(&getBody()->back());
+const OperationInst *MLFunction::getReturnStmt() const {
+ return cast<OperationInst>(&getBody()->back());
}
-OperationStmt *MLFunction::getReturnStmt() {
- return cast<OperationStmt>(&getBody()->back());
+OperationInst *MLFunction::getReturnStmt() {
+ return cast<OperationInst>(&getBody()->back());
}
-void MLFunction::walk(std::function<void(OperationStmt *)> callback) {
+void MLFunction::walk(std::function<void(OperationInst *)> callback) {
struct Walker : public StmtWalker<Walker> {
- std::function<void(OperationStmt *)> const &callback;
- Walker(std::function<void(OperationStmt *)> const &callback)
+ std::function<void(OperationInst *)> const &callback;
+ Walker(std::function<void(OperationInst *)> const &callback)
: callback(callback) {}
- void visitOperationStmt(OperationStmt *opStmt) { callback(opStmt); }
+ void visitOperationInst(OperationInst *opStmt) { callback(opStmt); }
};
Walker v(callback);
v.walk(this);
}
-void MLFunction::walkPostOrder(std::function<void(OperationStmt *)> callback) {
+void MLFunction::walkPostOrder(std::function<void(OperationInst *)> callback) {
struct Walker : public StmtWalker<Walker> {
- std::function<void(OperationStmt *)> const &callback;
- Walker(std::function<void(OperationStmt *)> const &callback)
+ std::function<void(OperationInst *)> const &callback;
+ Walker(std::function<void(OperationInst *)> const &callback)
: callback(callback) {}
- void visitOperationStmt(OperationStmt *opStmt) { callback(opStmt); }
+ void visitOperationInst(OperationInst *opStmt) { callback(opStmt); }
};
Walker v(callback);
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index da0bc4b1595..abc3e1cfda4 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -450,7 +450,7 @@ auto MLIRContext::getDiagnosticHandler() const -> DiagnosticHandlerTy {
/// This emits a diagnostic using the registered issue handle if present, or
/// with the default behavior if not. The MLIR compiler should not generally
-/// interact with this, it should use methods on Operation instead.
+/// interact with this, it should use methods on OperationInst instead.
void MLIRContext::emitDiagnostic(Location location, const llvm::Twine &message,
DiagnosticKind kind) const {
// Check to see if we are emitting a diagnostic on a fused location.
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 0526c6ea610..6a9b37560db 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1,4 +1,4 @@
-//===- Operation.cpp - MLIR Operation Class -------------------------------===//
+//===- Operation.cpp - Operation support code -----------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
@@ -15,8 +15,6 @@
// limitations under the License.
// =============================================================================
-#include "mlir/IR/Operation.h"
-#include "AttributeListStorage.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/MLIRContext.h"
@@ -53,200 +51,6 @@ OperationName OperationName::getFromOpaquePointer(void *pointer) {
OpAsmParser::~OpAsmParser() {}
//===----------------------------------------------------------------------===//
-// Operation class
-//===----------------------------------------------------------------------===//
-
-Operation::Operation(OperationName name, ArrayRef<NamedAttribute> attrs,
- Location location, MLIRContext *context)
- : Statement(Kind::Operation, location), name(name) {
- this->attrs = AttributeListStorage::get(attrs, context);
-
-#ifndef NDEBUG
- for (auto elt : attrs)
- assert(elt.second != nullptr && "Attributes cannot have null entries");
-#endif
-}
-
-Operation::~Operation() {}
-
-
-/// Return the function this operation is defined in.
-Function *Operation::getOperationFunction() {
- return llvm::cast<OperationStmt>(this)->getFunction();
-}
-
-/// Return the number of results this operation has.
-unsigned Operation::getNumResults() const {
- return llvm::cast<OperationStmt>(this)->getNumResults();
-}
-
-/// Return the indicated result.
-Value *Operation::getResult(unsigned idx) {
- return llvm::cast<OperationStmt>(this)->getResult(idx);
-}
-
-unsigned Operation::getNumSuccessors() const {
- assert(isTerminator() && "Only terminators have successors.");
- return llvm::cast<OperationStmt>(this)->getNumSuccessors();
-}
-
-unsigned Operation::getNumSuccessorOperands(unsigned index) const {
- assert(isTerminator() && "Only terminators have successors.");
- return llvm::cast<OperationStmt>(this)->getNumSuccessorOperands(index);
-}
-BasicBlock *Operation::getSuccessor(unsigned index) {
- assert(isTerminator() && "Only terminators have successors");
- return llvm::cast<OperationStmt>(this)->getSuccessor(index);
-}
-void Operation::setSuccessor(BasicBlock *block, unsigned index) {
- assert(isTerminator() && "Only terminators have successors");
- llvm::cast<OperationStmt>(this)->setSuccessor(block, index);
-}
-
-void Operation::eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) {
- assert(isTerminator() && "Only terminators have successors");
- return llvm::cast<OperationStmt>(this)->eraseSuccessorOperand(succIndex,
- opIndex);
-}
-auto Operation::getSuccessorOperands(unsigned index) const
- -> llvm::iterator_range<const_operand_iterator> {
- assert(isTerminator() && "Only terminators have successors.");
- unsigned succOperandIndex =
- llvm::cast<OperationStmt>(this)->getSuccessorOperandIndex(index);
- return {const_operand_iterator(this, succOperandIndex),
- const_operand_iterator(this, succOperandIndex +
- getNumSuccessorOperands(index))};
-}
-auto Operation::getSuccessorOperands(unsigned index)
- -> llvm::iterator_range<operand_iterator> {
- assert(isTerminator() && "Only terminators have successors.");
- unsigned succOperandIndex =
- llvm::cast<OperationStmt>(this)->getSuccessorOperandIndex(index);
- return {operand_iterator(this, succOperandIndex),
- operand_iterator(this,
- succOperandIndex + getNumSuccessorOperands(index))};
-}
-
-/// Return true if there are no users of any results of this operation.
-bool Operation::use_empty() const {
- for (auto *result : getResults())
- if (!result->use_empty())
- return false;
- return true;
-}
-
-ArrayRef<NamedAttribute> Operation::getAttrs() const {
- if (!attrs)
- return {};
- return attrs->getElements();
-}
-
-/// If an attribute exists with the specified name, change it to the new
-/// value. Otherwise, add a new attribute with the specified name/value.
-void Operation::setAttr(Identifier name, Attribute value) {
- assert(value && "attributes may never be null");
- auto origAttrs = getAttrs();
-
- SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end());
- auto *context = getContext();
-
- // If we already have this attribute, replace it.
- for (auto &elt : newAttrs)
- if (elt.first == name) {
- elt.second = value;
- attrs = AttributeListStorage::get(newAttrs, context);
- return;
- }
-
- // Otherwise, add it.
- newAttrs.push_back({name, value});
- attrs = AttributeListStorage::get(newAttrs, context);
-}
-
-/// Remove the attribute with the specified name if it exists. The return
-/// value indicates whether the attribute was present or not.
-auto Operation::removeAttr(Identifier name) -> RemoveResult {
- auto origAttrs = getAttrs();
- for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
- if (origAttrs[i].first == name) {
- SmallVector<NamedAttribute, 8> newAttrs;
- newAttrs.reserve(origAttrs.size() - 1);
- newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
- newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
- attrs = AttributeListStorage::get(newAttrs, getContext());
- return RemoveResult::Removed;
- }
- }
- return RemoveResult::NotFound;
-}
-
-/// Emit a note about this operation, reporting up to any diagnostic
-/// handlers that may be listening.
-void Operation::emitNote(const Twine &message) const {
- getContext()->emitDiagnostic(getLoc(), message,
- MLIRContext::DiagnosticKind::Note);
-}
-
-/// Emit a warning about this operation, reporting up to any diagnostic
-/// handlers that may be listening.
-void Operation::emitWarning(const Twine &message) const {
- getContext()->emitDiagnostic(getLoc(), message,
- MLIRContext::DiagnosticKind::Warning);
-}
-
-/// Emit an error about fatal conditions with this operation, reporting up to
-/// any diagnostic handlers that may be listening. This function always returns
-/// true. NOTE: This may terminate the containing application, only use when
-/// the IR is in an inconsistent state.
-bool Operation::emitError(const Twine &message) const {
- return getContext()->emitError(getLoc(), message);
-}
-
-/// Emit an error with the op name prefixed, like "'dim' op " which is
-/// convenient for verifiers.
-bool Operation::emitOpError(const Twine &message) const {
- return emitError(Twine('\'') + getName().getStringRef() + "' op " + message);
-}
-
-/// Remove this operation from its parent block and delete it.
-void Operation::erase() {
- return llvm::cast<OperationStmt>(this)->erase();
-}
-
-/// Attempt to constant fold this operation with the specified constant
-/// operand values. If successful, this returns false and fills in the
-/// results vector. If not, this returns true and results is unspecified.
-bool Operation::constantFold(ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results) const {
- if (auto *abstractOp = getAbstractOperation()) {
- // If we have a registered operation definition matching this one, use it to
- // try to constant fold the operation.
- if (!abstractOp->constantFoldHook(this, operands, results))
- return false;
-
- // Otherwise, fall back on the dialect hook to handle it.
- return abstractOp->dialect.constantFoldHook(this, operands, results);
- }
-
- // If this operation hasn't been registered or doesn't have abstract
- // operation, fall back to a dialect which matches the prefix.
- auto opName = getName().getStringRef();
- if (auto *dialect = getContext()->getRegisteredDialect(opName)) {
- return dialect->constantFoldHook(this, operands, results);
- }
-
- return true;
-}
-
-/// Methods for support type inquiry through isa, cast, and dyn_cast.
-bool Operation::classof(const Statement *stmt) {
- return stmt->getKind() == Statement::Kind::Operation;
-}
-bool Operation::classof(const IROperandOwner *ptr) {
- return ptr->getKind() == IROperandOwner::Kind::OperationStmt;
-}
-
-//===----------------------------------------------------------------------===//
// OpState trait class.
//===----------------------------------------------------------------------===//
@@ -290,19 +94,20 @@ void OpState::emitNote(const Twine &message) const {
// Op Trait implementations
//===----------------------------------------------------------------------===//
-bool OpTrait::impl::verifyZeroOperands(const Operation *op) {
+bool OpTrait::impl::verifyZeroOperands(const OperationInst *op) {
if (op->getNumOperands() != 0)
return op->emitOpError("requires zero operands");
return false;
}
-bool OpTrait::impl::verifyOneOperand(const Operation *op) {
+bool OpTrait::impl::verifyOneOperand(const OperationInst *op) {
if (op->getNumOperands() != 1)
return op->emitOpError("requires a single operand");
return false;
}
-bool OpTrait::impl::verifyNOperands(const Operation *op, unsigned numOperands) {
+bool OpTrait::impl::verifyNOperands(const OperationInst *op,
+ unsigned numOperands) {
if (op->getNumOperands() != numOperands) {
return op->emitOpError("expected " + Twine(numOperands) +
" operands, but found " +
@@ -311,7 +116,7 @@ bool OpTrait::impl::verifyNOperands(const Operation *op, unsigned numOperands) {
return false;
}
-bool OpTrait::impl::verifyAtLeastNOperands(const Operation *op,
+bool OpTrait::impl::verifyAtLeastNOperands(const OperationInst *op,
unsigned numOperands) {
if (op->getNumOperands() < numOperands)
return op->emitOpError("expected " + Twine(numOperands) +
@@ -331,7 +136,7 @@ static Type getTensorOrVectorElementType(Type type) {
return type;
}
-bool OpTrait::impl::verifyOperandsAreIntegerLike(const Operation *op) {
+bool OpTrait::impl::verifyOperandsAreIntegerLike(const OperationInst *op) {
for (auto *operand : op->getOperands()) {
auto type = getTensorOrVectorElementType(operand->getType());
if (!type.isIntOrIndex())
@@ -340,7 +145,7 @@ bool OpTrait::impl::verifyOperandsAreIntegerLike(const Operation *op) {
return false;
}
-bool OpTrait::impl::verifySameTypeOperands(const Operation *op) {
+bool OpTrait::impl::verifySameTypeOperands(const OperationInst *op) {
// Zero or one operand always have the "same" type.
unsigned nOperands = op->getNumOperands();
if (nOperands < 2)
@@ -354,25 +159,26 @@ bool OpTrait::impl::verifySameTypeOperands(const Operation *op) {
return false;
}
-bool OpTrait::impl::verifyZeroResult(const Operation *op) {
+bool OpTrait::impl::verifyZeroResult(const OperationInst *op) {
if (op->getNumResults() != 0)
return op->emitOpError("requires zero results");
return false;
}
-bool OpTrait::impl::verifyOneResult(const Operation *op) {
+bool OpTrait::impl::verifyOneResult(const OperationInst *op) {
if (op->getNumResults() != 1)
return op->emitOpError("requires one result");
return false;
}
-bool OpTrait::impl::verifyNResults(const Operation *op, unsigned numOperands) {
+bool OpTrait::impl::verifyNResults(const OperationInst *op,
+ unsigned numOperands) {
if (op->getNumResults() != numOperands)
return op->emitOpError("expected " + Twine(numOperands) + " results");
return false;
}
-bool OpTrait::impl::verifyAtLeastNResults(const Operation *op,
+bool OpTrait::impl::verifyAtLeastNResults(const OperationInst *op,
unsigned numOperands) {
if (op->getNumResults() < numOperands)
return op->emitOpError("expected " + Twine(numOperands) +
@@ -401,7 +207,7 @@ static bool verifyShapeMatch(Type type1, Type type2) {
return false;
}
-bool OpTrait::impl::verifySameOperandsAndResultShape(const Operation *op) {
+bool OpTrait::impl::verifySameOperandsAndResultShape(const OperationInst *op) {
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
return true;
@@ -419,7 +225,7 @@ bool OpTrait::impl::verifySameOperandsAndResultShape(const Operation *op) {
return false;
}
-bool OpTrait::impl::verifySameOperandsAndResultType(const Operation *op) {
+bool OpTrait::impl::verifySameOperandsAndResultType(const OperationInst *op) {
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
return true;
@@ -438,8 +244,8 @@ bool OpTrait::impl::verifySameOperandsAndResultType(const Operation *op) {
}
static bool verifyBBArguments(
- llvm::iterator_range<Operation::const_operand_iterator> operands,
- const BasicBlock *destBB, const Operation *op) {
+ llvm::iterator_range<OperationInst::const_operand_iterator> operands,
+ const BasicBlock *destBB, const OperationInst *op) {
unsigned operandCount = std::distance(operands.begin(), operands.end());
if (operandCount != destBB->getNumArguments())
return op->emitError("branch has " + Twine(operandCount) +
@@ -455,9 +261,9 @@ static bool verifyBBArguments(
return false;
}
-static bool verifyTerminatorSuccessors(const Operation *op) {
+static bool verifyTerminatorSuccessors(const OperationInst *op) {
// Verify that the operands lines up with the BB arguments in the successor.
- const Function *fn = op->getOperationFunction();
+ const Function *fn = op->getFunction();
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
auto *succ = op->getSuccessor(i);
if (succ->getFunction() != fn)
@@ -468,17 +274,15 @@ static bool verifyTerminatorSuccessors(const Operation *op) {
return false;
}
-bool OpTrait::impl::verifyIsTerminator(const Operation *op) {
+bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) {
// Verify that the operation is at the end of the respective parent block.
- if (op->getOperationFunction()->isML()) {
- auto *stmt = cast<OperationStmt>(op);
- StmtBlock *block = stmt->getBlock();
- if (!block || block->getContainingStmt() || &block->back() != stmt)
+ if (op->getFunction()->isML()) {
+ StmtBlock *block = op->getBlock();
+ if (!block || block->getContainingStmt() || &block->back() != op)
return op->emitOpError("must be the last statement in the ML function");
} else {
- auto *inst = cast<OperationInst>(op);
- const BasicBlock *block = inst->getBlock();
- if (!block || &block->back() != inst)
+ const BasicBlock *block = op->getBlock();
+ if (!block || &block->back() != op)
return op->emitOpError(
"must be the last instruction in the parent basic block.");
}
@@ -489,7 +293,7 @@ bool OpTrait::impl::verifyIsTerminator(const Operation *op) {
return false;
}
-bool OpTrait::impl::verifyResultsAreBoolLike(const Operation *op) {
+bool OpTrait::impl::verifyResultsAreBoolLike(const OperationInst *op) {
for (auto *result : op->getResults()) {
auto elementType = getTensorOrVectorElementType(result->getType());
auto intType = elementType.dyn_cast<IntegerType>();
@@ -501,7 +305,7 @@ bool OpTrait::impl::verifyResultsAreBoolLike(const Operation *op) {
return false;
}
-bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
+bool OpTrait::impl::verifyResultsAreFloatLike(const OperationInst *op) {
for (auto *result : op->getResults()) {
if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())
return op->emitOpError("requires a floating point type");
@@ -510,7 +314,7 @@ bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
return false;
}
-bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
+bool OpTrait::impl::verifyResultsAreIntegerLike(const OperationInst *op) {
for (auto *result : op->getResults()) {
auto type = getTensorOrVectorElementType(result->getType());
if (!type.isIntOrIndex())
@@ -543,7 +347,7 @@ bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
parser->addTypeToList(type, result->types);
}
-void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
+void impl::printBinaryOp(const OperationInst *op, OpAsmPrinter *p) {
*p << op->getName() << ' ' << *op->getOperand(0) << ", "
<< *op->getOperand(1);
p->printOptionalAttrDict(op->getAttrs());
@@ -569,7 +373,7 @@ bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
parser->addTypeToList(dstType, result->types);
}
-void impl::printCastOp(const Operation *op, OpAsmPrinter *p) {
+void impl::printCastOp(const OperationInst *op, OpAsmPrinter *p) {
*p << op->getName() << ' ' << *op->getOperand(0) << " : "
<< op->getOperand(0)->getType() << " to " << op->getResult(0)->getType();
}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 9e4d8bb180c..8c41d488a8b 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -58,12 +58,14 @@ void Pattern::anchor() {}
// RewritePattern and PatternRewriter implementation
//===----------------------------------------------------------------------===//
-void RewritePattern::rewrite(Operation *op, std::unique_ptr<PatternState> state,
+void RewritePattern::rewrite(OperationInst *op,
+ std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const {
rewrite(op, rewriter);
}
-void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
+void RewritePattern::rewrite(OperationInst *op,
+ PatternRewriter &rewriter) const {
llvm_unreachable("need to implement one of the rewrite functions!");
}
@@ -77,7 +79,7 @@ PatternRewriter::~PatternRewriter() {
/// clients can specify a list of other nodes that this replacement may make
/// (perhaps transitively) dead. If any of those ops are dead, this will
/// remove them as well.
-void PatternRewriter::replaceOp(Operation *op, ArrayRef<Value *> newValues,
+void PatternRewriter::replaceOp(OperationInst *op, ArrayRef<Value *> newValues,
ArrayRef<Value *> valuesToRemoveIfDead) {
// Notify the rewriter subclass that we're about to replace this root.
notifyRootReplaced(op);
@@ -97,7 +99,8 @@ void PatternRewriter::replaceOp(Operation *op, ArrayRef<Value *> newValues,
/// op and newOp are known to have the same number of results, replace the
/// uses of op with uses of newOp
void PatternRewriter::replaceOpWithResultsOfAnotherOp(
- Operation *op, Operation *newOp, ArrayRef<Value *> valuesToRemoveIfDead) {
+ OperationInst *op, OperationInst *newOp,
+ ArrayRef<Value *> valuesToRemoveIfDead) {
assert(op->getNumResults() == newOp->getNumResults() &&
"replacement op doesn't match results of original op");
if (op->getNumResults() == 1)
@@ -117,7 +120,7 @@ void PatternRewriter::replaceOpWithResultsOfAnotherOp(
/// should remove if they are dead at this point.
///
void PatternRewriter::updatedRootInPlace(
- Operation *op, ArrayRef<Value *> valuesToRemoveIfDead) {
+ OperationInst *op, ArrayRef<Value *> valuesToRemoveIfDead) {
// Notify the rewriter subclass that we're about to replace this root.
notifyRootUpdated(op);
@@ -132,7 +135,7 @@ void PatternRewriter::updatedRootInPlace(
/// Find the highest benefit pattern available in the pattern set for the DAG
/// rooted at the specified node. This returns the pattern if found, or null
/// if there are no matches.
-auto PatternMatcher::findMatch(Operation *op) -> MatchResult {
+auto PatternMatcher::findMatch(OperationInst *op) -> MatchResult {
// TODO: This is a completely trivial implementation, expand this in the
// future.
diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp
index 8bff23d41ed..19457efa8c3 100644
--- a/mlir/lib/IR/Statement.cpp
+++ b/mlir/lib/IR/Statement.cpp
@@ -15,6 +15,7 @@
// limitations under the License.
// =============================================================================
+#include "AttributeListStorage.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
@@ -65,8 +66,8 @@ Statement::~Statement() {
/// Destroy this statement or one of its subclasses.
void Statement::destroy() {
switch (this->getKind()) {
- case Kind::Operation:
- cast<OperationStmt>(this)->destroy();
+ case Kind::OperationInst:
+ cast<OperationInst>(this)->destroy();
break;
case Kind::For:
delete cast<ForStmt>(this);
@@ -95,7 +96,7 @@ const Value *Statement::getOperand(unsigned idx) const {
// it is an induction variable, or it is a result of affine apply operation
// with dimension id arguments.
bool Value::isValidDim() const {
- if (auto *stmt = getDefiningStmt()) {
+ if (auto *stmt = getDefiningInst()) {
// Top level statement or constant operation is ok.
if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>())
return true;
@@ -113,7 +114,7 @@ bool Value::isValidDim() const {
// the top level, or it is a result of affine apply operation with symbol
// arguments.
bool Value::isValidSymbol() const {
- if (auto *stmt = getDefiningStmt()) {
+ if (auto *stmt = getDefiningInst()) {
// Top level statement or constant operation is ok.
if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>())
return true;
@@ -133,8 +134,8 @@ void Statement::setOperand(unsigned idx, Value *value) {
unsigned Statement::getNumOperands() const {
switch (getKind()) {
- case Kind::Operation:
- return cast<OperationStmt>(this)->getNumOperands();
+ case Kind::OperationInst:
+ return cast<OperationInst>(this)->getNumOperands();
case Kind::For:
return cast<ForStmt>(this)->getNumOperands();
case Kind::If:
@@ -144,8 +145,8 @@ unsigned Statement::getNumOperands() const {
MutableArrayRef<StmtOperand> Statement::getStmtOperands() {
switch (getKind()) {
- case Kind::Operation:
- return cast<OperationStmt>(this)->getStmtOperands();
+ case Kind::OperationInst:
+ return cast<OperationInst>(this)->getStmtOperands();
case Kind::For:
return cast<ForStmt>(this)->getStmtOperands();
case Kind::If:
@@ -177,7 +178,7 @@ bool Statement::emitError(const Twine &message) const {
// Returns whether the Statement is a terminator.
bool Statement::isTerminator() const {
- if (auto *op = dyn_cast<OperationStmt>(this))
+ if (auto *op = dyn_cast<OperationInst>(this))
return op->isTerminator();
return false;
}
@@ -264,11 +265,11 @@ void Statement::dropAllReferences() {
}
//===----------------------------------------------------------------------===//
-// OperationStmt
+// OperationInst
//===----------------------------------------------------------------------===//
-/// Create a new OperationStmt with the specific fields.
-OperationStmt *OperationStmt::create(Location location, OperationName name,
+/// Create a new OperationInst with the specific fields.
+OperationInst *OperationInst::create(Location location, OperationName name,
ArrayRef<Value *> operands,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes,
@@ -285,9 +286,9 @@ OperationStmt *OperationStmt::create(Location location, OperationName name,
resultTypes.size(), numSuccessors, numSuccessors, numOperands);
void *rawMem = malloc(byteSize);
- // Initialize the OperationStmt part of the statement.
+ // Initialize the OperationInst part of the statement.
auto stmt = ::new (rawMem)
- OperationStmt(location, name, numOperands, resultTypes.size(),
+ OperationInst(location, name, numOperands, resultTypes.size(),
numSuccessors, attributes, context);
// Initialize the results and operands.
@@ -355,15 +356,22 @@ OperationStmt *OperationStmt::create(Location location, OperationName name,
return stmt;
}
-OperationStmt::OperationStmt(Location location, OperationName name,
+OperationInst::OperationInst(Location location, OperationName name,
unsigned numOperands, unsigned numResults,
unsigned numSuccessors,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context)
- : Operation(name, attributes, location, context), numOperands(numOperands),
- numResults(numResults), numSuccs(numSuccessors) {}
+ : Statement(Kind::OperationInst, location), numOperands(numOperands),
+ numResults(numResults), numSuccs(numSuccessors), name(name) {
+#ifndef NDEBUG
+ for (auto elt : attributes)
+ assert(elt.second != nullptr && "Attributes cannot have null entries");
+#endif
-OperationStmt::~OperationStmt() {
+ this->attrs = AttributeListStorage::get(attributes, context);
+}
+
+OperationInst::~OperationInst() {
// Explicitly run the destructors for the operands and results.
for (auto &operand : getStmtOperands())
operand.~StmtOperand();
@@ -377,13 +385,27 @@ OperationStmt::~OperationStmt() {
successor.~StmtBlockOperand();
}
-void OperationStmt::destroy() {
- this->~OperationStmt();
+/// Return true if there are no users of any results of this operation.
+bool OperationInst::use_empty() const {
+ for (auto *result : getResults())
+ if (!result->use_empty())
+ return false;
+ return true;
+}
+
+ArrayRef<NamedAttribute> OperationInst::getAttrs() const {
+ if (!attrs)
+ return {};
+ return attrs->getElements();
+}
+
+void OperationInst::destroy() {
+ this->~OperationInst();
free(this);
}
/// Return the context this operation is associated with.
-MLIRContext *OperationStmt::getContext() const {
+MLIRContext *OperationInst::getContext() const {
// If we have a result or operand type, that is a constant time way to get
// to the context.
if (getNumResults())
@@ -396,9 +418,9 @@ MLIRContext *OperationStmt::getContext() const {
return getFunction()->getContext();
}
-bool OperationStmt::isReturn() const { return isa<ReturnOp>(); }
+bool OperationInst::isReturn() const { return isa<ReturnOp>(); }
-void OperationStmt::setSuccessor(BasicBlock *block, unsigned index) {
+void OperationInst::setSuccessor(BasicBlock *block, unsigned index) {
assert(index < getNumSuccessors());
getBlockOperands()[index].set(block);
}
@@ -413,6 +435,96 @@ void OperationInst::eraseOperand(unsigned index) {
Operands[getNumOperands()].~StmtOperand();
}
+auto OperationInst::getSuccessorOperands(unsigned index) const
+ -> llvm::iterator_range<const_operand_iterator> {
+ assert(isTerminator() && "Only terminators have successors.");
+ unsigned succOperandIndex = getSuccessorOperandIndex(index);
+ return {const_operand_iterator(this, succOperandIndex),
+ const_operand_iterator(this, succOperandIndex +
+ getNumSuccessorOperands(index))};
+}
+auto OperationInst::getSuccessorOperands(unsigned index)
+ -> llvm::iterator_range<operand_iterator> {
+ assert(isTerminator() && "Only terminators have successors.");
+ unsigned succOperandIndex = getSuccessorOperandIndex(index);
+ return {operand_iterator(this, succOperandIndex),
+ operand_iterator(this,
+ succOperandIndex + getNumSuccessorOperands(index))};
+}
+
+/// If an attribute exists with the specified name, change it to the new
+/// value. Otherwise, add a new attribute with the specified name/value.
+void OperationInst::setAttr(Identifier name, Attribute value) {
+ assert(value && "attributes may never be null");
+ auto origAttrs = getAttrs();
+
+ SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end());
+ auto *context = getContext();
+
+ // If we already have this attribute, replace it.
+ for (auto &elt : newAttrs)
+ if (elt.first == name) {
+ elt.second = value;
+ attrs = AttributeListStorage::get(newAttrs, context);
+ return;
+ }
+
+ // Otherwise, add it.
+ newAttrs.push_back({name, value});
+ attrs = AttributeListStorage::get(newAttrs, context);
+}
+
+/// Remove the attribute with the specified name if it exists. The return
+/// value indicates whether the attribute was present or not.
+auto OperationInst::removeAttr(Identifier name) -> RemoveResult {
+ auto origAttrs = getAttrs();
+ for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
+ if (origAttrs[i].first == name) {
+ SmallVector<NamedAttribute, 8> newAttrs;
+ newAttrs.reserve(origAttrs.size() - 1);
+ newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
+ newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
+ attrs = AttributeListStorage::get(newAttrs, getContext());
+ return RemoveResult::Removed;
+ }
+ }
+ return RemoveResult::NotFound;
+}
+
+/// Attempt to constant fold this operation with the specified constant
+/// operand values. If successful, this returns false and fills in the
+/// results vector. If not, this returns true and results is unspecified.
+bool OperationInst::constantFold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<Attribute> &results) const {
+ if (auto *abstractOp = getAbstractOperation()) {
+ // If we have a registered operation definition matching this one, use it to
+ // try to constant fold the operation.
+ if (!abstractOp->constantFoldHook(llvm::cast<OperationInst>(this), operands,
+ results))
+ return false;
+
+ // Otherwise, fall back on the dialect hook to handle it.
+ return abstractOp->dialect.constantFoldHook(llvm::cast<OperationInst>(this),
+ operands, results);
+ }
+
+ // If this operation hasn't been registered or doesn't have abstract
+ // operation, fall back to a dialect which matches the prefix.
+ auto opName = getName().getStringRef();
+ if (auto *dialect = getContext()->getRegisteredDialect(opName)) {
+ return dialect->constantFoldHook(llvm::cast<OperationInst>(this), operands,
+ results);
+ }
+
+ return true;
+}
+
+/// Emit an error with the op name prefixed, like "'dim' op " which is
+/// convenient for verifiers.
+bool OperationInst::emitOpError(const Twine &message) const {
+ return emitError(Twine('\'') + getName().getStringRef() + "' op " + message);
+}
+
//===----------------------------------------------------------------------===//
// ForStmt
//===----------------------------------------------------------------------===//
@@ -625,7 +737,7 @@ Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap,
SmallVector<Value *, 8> operands;
SmallVector<StmtBlock *, 2> successors;
- if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
+ if (auto *opStmt = dyn_cast<OperationInst>(this)) {
operands.reserve(getNumOperands() + opStmt->getNumSuccessors());
if (!opStmt->isTerminator()) {
@@ -653,8 +765,8 @@ Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap,
operands.push_back(nullptr);
// Remap the successors operands.
- for (auto &operand : opStmt->getSuccessorOperands(succ))
- operands.push_back(remapOperand(operand.get()));
+ for (auto *operand : opStmt->getSuccessorOperands(succ))
+ operands.push_back(remapOperand(operand));
}
}
@@ -662,7 +774,7 @@ Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap,
resultTypes.reserve(opStmt->getNumResults());
for (auto *result : opStmt->getResults())
resultTypes.push_back(result->getType());
- auto *newOp = OperationStmt::create(getLoc(), opStmt->getName(), operands,
+ auto *newOp = OperationInst::create(getLoc(), opStmt->getName(), operands,
resultTypes, opStmt->getAttrs(),
successors, context);
// Remember the mapping of any results.
diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp
index a50861a3060..cfb09e6bf45 100644
--- a/mlir/lib/IR/StmtBlock.cpp
+++ b/mlir/lib/IR/StmtBlock.cpp
@@ -100,13 +100,13 @@ void StmtBlock::eraseArgument(unsigned index) {
// Terminator management
//===----------------------------------------------------------------------===//
-OperationStmt *StmtBlock::getTerminator() {
+OperationInst *StmtBlock::getTerminator() {
if (empty())
return nullptr;
// Check if the last instruction is a terminator.
auto &backInst = statements.back();
- auto *opStmt = dyn_cast<OperationStmt>(&backInst);
+ auto *opStmt = dyn_cast<OperationInst>(&backInst);
if (!opStmt || !opStmt->isTerminator())
return nullptr;
return opStmt;
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index db58e126e61..41a6d80e2a2 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -28,29 +28,13 @@ OperationInst *Value::getDefiningInst() {
return nullptr;
}
-/// If this value is the result of an OperationStmt, return the statement
-/// that defines it.
-OperationStmt *Value::getDefiningStmt() {
- if (auto *result = dyn_cast<StmtResult>(this))
- return result->getOwner();
- return nullptr;
-}
-
-Operation *Value::getDefiningOperation() {
- if (auto *inst = getDefiningInst())
- return inst;
- if (auto *stmt = getDefiningStmt())
- return stmt;
- return nullptr;
-}
-
-/// Return the function that this Valueis defined in.
+/// Return the function that this Value is defined in.
Function *Value::getFunction() {
switch (getKind()) {
case Value::Kind::BlockArgument:
return cast<BlockArgument>(this)->getFunction();
case Value::Kind::StmtResult:
- return getDefiningStmt()->getFunction();
+ return getDefiningInst()->getFunction();
case Value::Kind::ForStmt:
return cast<ForStmt>(this)->getFunction();
}
@@ -73,8 +57,8 @@ void IRObjectWithUseList::replaceAllUsesWith(IRObjectWithUseList *newValue) {
/// Return the context this operation is associated with.
MLIRContext *IROperandOwner::getContext() const {
switch (getKind()) {
- case Kind::OperationStmt:
- return cast<OperationStmt>(this)->getContext();
+ case Kind::OperationInst:
+ return cast<OperationInst>(this)->getContext();
case Kind::ForStmt:
return cast<ForStmt>(this)->getContext();
case Kind::IfStmt:
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 06495eb81ab..35891f5784b 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -103,7 +103,7 @@ private:
namespace {
using CreateOperationFunction =
- std::function<Operation *(const OperationState &)>;
+ std::function<OperationInst *(const OperationState &)>;
/// This class implement support for parsing global entities like types and
/// shared entities like SSA names. It is intended to be subclassed by
@@ -1915,8 +1915,10 @@ public:
// Operations
ParseResult parseOperation(const CreateOperationFunction &createOpFunc);
- Operation *parseVerboseOperation(const CreateOperationFunction &createOpFunc);
- Operation *parseCustomOperation(const CreateOperationFunction &createOpFunc);
+ OperationInst *
+ parseVerboseOperation(const CreateOperationFunction &createOpFunc);
+ OperationInst *
+ parseCustomOperation(const CreateOperationFunction &createOpFunc);
/// Parse a single operation successor and it's operand list.
virtual bool parseSuccessorAndUseList(BasicBlock *&dest,
@@ -2184,7 +2186,7 @@ FunctionParser::parseOperation(const CreateOperationFunction &createOpFunc) {
return ParseFailure;
}
- Operation *op;
+ OperationInst *op;
if (getToken().is(Token::bare_identifier) || getToken().isKeyword())
op = parseCustomOperation(createOpFunc);
else if (getToken().is(Token::string))
@@ -2220,7 +2222,7 @@ FunctionParser::parseOperation(const CreateOperationFunction &createOpFunc) {
return ParseSuccess;
}
-Operation *FunctionParser::parseVerboseOperation(
+OperationInst *FunctionParser::parseVerboseOperation(
const CreateOperationFunction &createOpFunc) {
// Get location information for the operation.
@@ -2516,7 +2518,7 @@ private:
};
} // end anonymous namespace.
-Operation *FunctionParser::parseCustomOperation(
+OperationInst *FunctionParser::parseCustomOperation(
const CreateOperationFunction &createOpFunc) {
auto opLoc = getToken().getLoc();
auto opName = getTokenSpelling();
@@ -2746,7 +2748,7 @@ ParseResult CFGFunctionParser::parseBasicBlock() {
// into.
builder.setInsertionPointToEnd(block);
- auto createOpFunc = [&](const OperationState &result) -> Operation * {
+ auto createOpFunc = [&](const OperationState &result) -> OperationInst * {
return builder.createOperation(result);
};
@@ -3149,7 +3151,7 @@ ParseResult MLFunctionParser::parseElseClause(StmtBlock *elseClause) {
/// Parse a list of statements ending with `return` or `}`
///
ParseResult MLFunctionParser::parseStatements(StmtBlock *block) {
- auto createOpFunc = [&](const OperationState &state) -> Operation * {
+ auto createOpFunc = [&](const OperationState &state) -> OperationInst * {
return builder.createOperation(state);
};
diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp
index 7611c6e741b..19a4c8d1afe 100644
--- a/mlir/lib/StandardOps/StandardOps.cpp
+++ b/mlir/lib/StandardOps/StandardOps.cpp
@@ -56,7 +56,7 @@ struct MemRefCastFolder : public RewritePattern {
MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
: RewritePattern(rootOpName, 1, context) {}
- PatternMatchResult match(Operation *op) const override {
+ PatternMatchResult match(OperationInst *op) const override {
for (auto *operand : op->getOperands())
if (matchPattern(operand, m_Op<MemRefCastOp>()))
return matchSuccess();
@@ -64,9 +64,9 @@ struct MemRefCastFolder : public RewritePattern {
return matchFailure();
}
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+ void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
- if (auto *memref = op->getOperand(i)->getDefiningOperation())
+ if (auto *memref = op->getOperand(i)->getDefiningInst())
if (auto cast = memref->dyn_cast<MemRefCastOp>())
op->setOperand(i, cast->getOperand());
rewriter.updatedRootInPlace(op);
@@ -122,7 +122,7 @@ struct SimplifyAddX0 : public RewritePattern {
SimplifyAddX0(MLIRContext *context)
: RewritePattern(AddIOp::getOperationName(), 1, context) {}
- PatternMatchResult match(Operation *op) const override {
+ PatternMatchResult match(OperationInst *op) const override {
auto addi = op->cast<AddIOp>();
if (matchPattern(addi->getOperand(1), m_Zero()))
@@ -130,7 +130,7 @@ struct SimplifyAddX0 : public RewritePattern {
return matchFailure();
}
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+ void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
rewriter.replaceOp(op, op->getOperand(0));
}
};
@@ -228,7 +228,7 @@ struct SimplifyAllocConst : public RewritePattern {
SimplifyAllocConst(MLIRContext *context)
: RewritePattern(AllocOp::getOperationName(), 1, context) {}
- PatternMatchResult match(Operation *op) const override {
+ PatternMatchResult match(OperationInst *op) const override {
auto alloc = op->cast<AllocOp>();
// Check to see if any dimensions operands are constants. If so, we can
@@ -239,7 +239,7 @@ struct SimplifyAllocConst : public RewritePattern {
return matchFailure();
}
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+ void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
auto allocOp = op->cast<AllocOp>();
auto memrefType = allocOp->getType();
@@ -258,7 +258,7 @@ struct SimplifyAllocConst : public RewritePattern {
newShapeConstants.push_back(dimSize);
continue;
}
- auto *defOp = allocOp->getOperand(dynamicDimPos)->getDefiningOperation();
+ auto *defOp = allocOp->getOperand(dynamicDimPos)->getDefiningInst();
OpPointer<ConstantIndexOp> constantIndexOp;
if (defOp && (constantIndexOp = defOp->dyn_cast<ConstantIndexOp>())) {
// Dynamic shape dimension will be folded.
@@ -1105,7 +1105,7 @@ struct SimplifyMulX1 : public RewritePattern {
SimplifyMulX1(MLIRContext *context)
: RewritePattern(MulIOp::getOperationName(), 1, context) {}
- PatternMatchResult match(Operation *op) const override {
+ PatternMatchResult match(OperationInst *op) const override {
auto muli = op->cast<MulIOp>();
if (matchPattern(muli->getOperand(1), m_One()))
@@ -1113,7 +1113,7 @@ struct SimplifyMulX1 : public RewritePattern {
return matchFailure();
}
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+ void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
rewriter.replaceOp(op, op->getOperand(0));
}
};
@@ -1308,14 +1308,14 @@ struct SimplifyXMinusX : public RewritePattern {
SimplifyXMinusX(MLIRContext *context)
: RewritePattern(SubIOp::getOperationName(), 1, context) {}
- PatternMatchResult match(Operation *op) const override {
+ PatternMatchResult match(OperationInst *op) const override {
auto subi = op->cast<SubIOp>();
if (subi->getOperand(0) == subi->getOperand(1))
return matchSuccess();
return matchFailure();
}
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+ void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
auto subi = op->cast<SubIOp>();
auto result =
rewriter.create<ConstantIntOp>(op->getLoc(), 0, subi->getType());
diff --git a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp
index 02b4c4674ab..e4243a6de25 100644
--- a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp
+++ b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp
@@ -86,14 +86,14 @@ void VectorTransferReadOp::build(Builder *builder, OperationState *result,
result->addTypes(vectorType);
}
-llvm::iterator_range<Operation::operand_iterator>
+llvm::iterator_range<OperationInst::operand_iterator>
VectorTransferReadOp::getIndices() {
auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
auto end = begin + getMemRefType().getRank();
return {begin, end};
}
-llvm::iterator_range<Operation::const_operand_iterator>
+llvm::iterator_range<OperationInst::const_operand_iterator>
VectorTransferReadOp::getIndices() const {
auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
auto end = begin + getMemRefType().getRank();
@@ -303,14 +303,14 @@ void VectorTransferWriteOp::build(Builder *builder, OperationState *result,
builder->getAffineMapAttr(permutationMap));
}
-llvm::iterator_range<Operation::operand_iterator>
+llvm::iterator_range<OperationInst::operand_iterator>
VectorTransferWriteOp::getIndices() {
auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
auto end = begin + getMemRefType().getRank();
return {begin, end};
}
-llvm::iterator_range<Operation::const_operand_iterator>
+llvm::iterator_range<OperationInst::const_operand_iterator>
VectorTransferWriteOp::getIndices() const {
auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
auto end = begin + getMemRefType().getRank();
diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
index a4d474dc24a..713aa0b1791 100644
--- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
@@ -111,8 +111,8 @@ private:
/// descriptor and get the pointer to the element indexed by the linearized
/// subscript. Return nullptr on errors.
llvm::Value *emitMemRefElementAccess(
- const Value *memRef, const Operation &op,
- llvm::iterator_range<Operation::const_operand_iterator> opIndices);
+ const Value *memRef, const OperationInst &op,
+ llvm::iterator_range<OperationInst::const_operand_iterator> opIndices);
/// Emit LLVM IR corresponding to the given Alloc `op`. In particular, create
/// a Value for the MemRef descriptor, store any dynamic sizes passed to
@@ -307,7 +307,7 @@ ModuleLowerer::linearizeSubscripts(ArrayRef<llvm::Value *> indices,
// the location of `op` and return true. Return false if the type is supported.
// TODO(zinenko): this function should disappear when the conversion fully
// supports MemRefs.
-static bool checkSupportedMemRefType(MemRefType type, const Operation &op) {
+static bool checkSupportedMemRefType(MemRefType type, const OperationInst &op) {
if (!type.getAffineMaps().empty())
return op.emitError("NYI: memrefs with affine maps");
if (type.getMemorySpace() != 0)
@@ -316,8 +316,8 @@ static bool checkSupportedMemRefType(MemRefType type, const Operation &op) {
}
llvm::Value *ModuleLowerer::emitMemRefElementAccess(
- const Value *memRef, const Operation &op,
- llvm::iterator_range<Operation::const_operand_iterator> opIndices) {
+ const Value *memRef, const OperationInst &op,
+ llvm::iterator_range<OperationInst::const_operand_iterator> opIndices) {
auto type = memRef->getType().dyn_cast<MemRefType>();
assert(type && "expected memRef value to have a MemRef type");
if (checkSupportedMemRefType(type, op))
@@ -425,7 +425,7 @@ ModuleLowerer::emitMemRefDealloc(ConstOpPointer<DeallocOp> deallocOp) {
// This forcibly recreates the APFloat with IEEESingle semantics to make sure
// LLVM constructs a `float` constant.
static llvm::ConstantFP *getFloatConstant(APFloat APvalue,
- const Operation &inst,
+ const OperationInst &inst,
llvm::LLVMContext *context) {
bool unused;
APFloat::opStatus status = APvalue.convert(
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 575ae2e1c9b..4b198589e2c 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -50,10 +50,10 @@ struct CSE : public FunctionPass {
};
// TODO(riverriddle) Handle commutative operations.
-struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
- static unsigned getHashValue(const Operation *op) {
+struct SimpleOperationInfo : public llvm::DenseMapInfo<OperationInst *> {
+ static unsigned getHashValue(const OperationInst *op) {
// Hash the operations based upon their:
- // - Operation Name
+ // - OperationInst Name
// - Attributes
// - Result Types
// - Operands
@@ -62,7 +62,7 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
hash_combine_range(op->result_type_begin(), op->result_type_end()),
hash_combine_range(op->operand_begin(), op->operand_end()));
}
- static bool isEqual(const Operation *lhs, const Operation *rhs) {
+ static bool isEqual(const OperationInst *lhs, const OperationInst *rhs) {
if (lhs == rhs)
return true;
if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
@@ -93,8 +93,8 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
struct CSEImpl {
using AllocatorTy = llvm::RecyclingAllocator<
llvm::BumpPtrAllocator,
- llvm::ScopedHashTableVal<Operation *, Operation *>>;
- using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
+ llvm::ScopedHashTableVal<OperationInst *, OperationInst *>>;
+ using ScopedMapTy = llvm::ScopedHashTable<OperationInst *, OperationInst *,
SimpleOperationInfo, AllocatorTy>;
/// Erase any operations that were marked as dead during simplification.
@@ -104,7 +104,7 @@ struct CSEImpl {
}
/// Attempt to eliminate a redundant operation.
- void simplifyOperation(Operation *op) {
+ void simplifyOperation(OperationInst *op) {
// TODO(riverriddle) We currently only eliminate non side-effecting
// operations.
if (!op->hasNoSideEffect())
@@ -141,7 +141,7 @@ struct CSEImpl {
ScopedMapTy knownValues;
/// Operations marked as dead and to be erased.
- std::vector<Operation *> opsToErase;
+ std::vector<OperationInst *> opsToErase;
};
/// Common sub-expression elimination for CFG functions.
@@ -224,7 +224,7 @@ struct MLCSE : public CSEImpl, StmtWalker<MLCSE> {
StmtWalker<MLCSE>::walk(Start, End);
}
- void visitOperationStmt(OperationStmt *stmt) { simplifyOperation(stmt); }
+ void visitOperationInst(OperationInst *stmt) { simplifyOperation(stmt); }
};
} // end anonymous namespace
diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp
index 84507b91703..365533561f9 100644
--- a/mlir/lib/Transforms/ComposeAffineMaps.cpp
+++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp
@@ -42,12 +42,12 @@ namespace {
// with no remaining uses are collected and erased after the walk.
// TODO(andydavis) Remove this when Chris adds instruction combiner pass.
struct ComposeAffineMaps : public FunctionPass, StmtWalker<ComposeAffineMaps> {
- std::vector<OperationStmt *> affineApplyOpsToErase;
+ std::vector<OperationInst *> affineApplyOpsToErase;
explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {}
using StmtListType = llvm::iplist<Statement>;
void walk(StmtListType::iterator Start, StmtListType::iterator End);
- void visitOperationStmt(OperationStmt *stmt);
+ void visitOperationInst(OperationInst *stmt);
PassResult runOnMLFunction(MLFunction *f) override;
using StmtWalker<ComposeAffineMaps>::walk;
@@ -72,7 +72,7 @@ void ComposeAffineMaps::walk(StmtListType::iterator Start,
}
}
-void ComposeAffineMaps::visitOperationStmt(OperationStmt *opStmt) {
+void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) {
if (auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>()) {
forwardSubstitute(affineApplyOp);
bool allUsesEmpty = true;
diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp
index b6b1dec7b17..a83e625c240 100644
--- a/mlir/lib/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Transforms/ConstantFold.cpp
@@ -31,13 +31,14 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
// All constants in the function post folding.
SmallVector<Value *, 8> existingConstants;
- // Operation statements that were folded and that need to be erased.
- std::vector<OperationStmt *> opStmtsToErase;
+ // Operations that were folded and that need to be erased.
+ std::vector<OperationInst *> opStmtsToErase;
using ConstantFactoryType = std::function<Value *(Attribute, Type)>;
- bool foldOperation(Operation *op, SmallVectorImpl<Value *> &existingConstants,
+ bool foldOperation(OperationInst *op,
+ SmallVectorImpl<Value *> &existingConstants,
ConstantFactoryType constantFactory);
- void visitOperationStmt(OperationStmt *stmt);
+ void visitOperationInst(OperationInst *stmt);
void visitForStmt(ForStmt *stmt);
PassResult runOnCFGFunction(CFGFunction *f) override;
PassResult runOnMLFunction(MLFunction *f) override;
@@ -52,7 +53,7 @@ char ConstantFold::passID = 0;
/// constants are found, we keep track of them in the existingConstants list.
///
/// This returns false if the operation was successfully folded.
-bool ConstantFold::foldOperation(Operation *op,
+bool ConstantFold::foldOperation(OperationInst *op,
SmallVectorImpl<Value *> &existingConstants,
ConstantFactoryType constantFactory) {
// If this operation is already a constant, just remember it for cleanup
@@ -67,7 +68,7 @@ bool ConstantFold::foldOperation(Operation *op,
SmallVector<Attribute, 8> operandConstants;
for (auto *operand : op->getOperands()) {
Attribute operandCst = nullptr;
- if (auto *operandOp = operand->getDefiningOperation()) {
+ if (auto *operandOp = operand->getDefiningInst()) {
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
operandCst = operandConstantOp->getValue();
}
@@ -138,8 +139,8 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
return success();
}
-// Override the walker's operation statement visit for constant folding.
-void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
+// Override the walker's operation visitor for constant folding.
+void ConstantFold::visitOperationInst(OperationInst *stmt) {
auto constantFactory = [&](Attribute value, Type type) -> Value * {
FuncBuilder builder(stmt);
return builder.create<ConstantOp>(stmt->getLoc(), value, type);
@@ -172,7 +173,7 @@ PassResult ConstantFold::runOnMLFunction(MLFunction *f) {
// around dead constants. Check for them now and remove them.
for (auto *cst : existingConstants) {
if (cst->use_empty())
- cst->getDefiningStmt()->erase();
+ cst->getDefiningInst()->erase();
}
return success();
diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp
index fefe9f700c4..ca158a17e92 100644
--- a/mlir/lib/Transforms/ConvertToCFG.cpp
+++ b/mlir/lib/Transforms/ConvertToCFG.cpp
@@ -47,14 +47,14 @@ public:
void visitForStmt(ForStmt *forStmt);
void visitIfStmt(IfStmt *ifStmt);
- void visitOperationStmt(OperationStmt *opStmt);
+ void visitOperationInst(OperationInst *opStmt);
private:
Value *getConstantIndexValue(int64_t value);
void visitStmtBlock(StmtBlock *stmtBlock);
Value *buildMinMaxReductionSeq(
Location loc, CmpIPredicate predicate,
- llvm::iterator_range<Operation::result_iterator> values);
+ llvm::iterator_range<OperationInst::result_iterator> values);
CFGFunction *cfgFunc;
FuncBuilder builder;
@@ -64,7 +64,7 @@ private:
};
} // end anonymous namespace
-// Return a vector of OperationStmt's arguments as Values. For each
+// Return a vector of OperationInst's arguments as Values. For each
// statement operands, represented as Value, lookup its Value conterpart in
// the valueRemapping table.
static llvm::SmallVector<mlir::Value *, 4>
@@ -84,7 +84,7 @@ operandsAs(Statement *opStmt,
// remains the same but the values must be updated to be Values. Update the
// mapping Value->Value as the conversion is performed. The operation
// instruction is appended to current block (end of SESE region).
-void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) {
+void FunctionConverter::visitOperationInst(OperationInst *opStmt) {
// Set up basic operation state (context, name, operands).
OperationState state(cfgFunc->getContext(), opStmt->getLoc(),
opStmt->getName());
@@ -136,7 +136,7 @@ void FunctionConverter::visitStmtBlock(StmtBlock *stmtBlock) {
// recognize as a reduction by the subsequent passes.
Value *FunctionConverter::buildMinMaxReductionSeq(
Location loc, CmpIPredicate predicate,
- llvm::iterator_range<Operation::result_iterator> values) {
+ llvm::iterator_range<OperationInst::result_iterator> values) {
assert(!llvm::empty(values) && "empty min/max chain");
auto valueIt = values.begin();
@@ -600,7 +600,7 @@ void ModuleConverter::replaceReferences() {
// operation "op" and containing an MLFunction-typed value with the result of
// converting "func" to a CFGFunction.
static inline void replaceMLFunctionAttr(
- Operation &op, Identifier name, const Function *func,
+ OperationInst &op, Identifier name, const Function *func,
const llvm::DenseMap<MLFunction *, CFGFunction *> &generatedFuncs) {
if (!func->isML())
return;
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index cc2ca32421b..ed184dc9421 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -67,7 +67,7 @@ struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> {
PassResult runOnMLFunction(MLFunction *f) override;
void runOnForStmt(ForStmt *forStmt);
- void visitOperationStmt(OperationStmt *opStmt);
+ void visitOperationInst(OperationInst *opStmt);
bool generateDma(const MemRefRegion &region, ForStmt *forStmt,
uint64_t *sizeInBytes);
@@ -108,7 +108,7 @@ FunctionPass *mlir::createDmaGenerationPass(unsigned slowMemorySpace,
// Gather regions to promote to buffers in faster memory space.
// TODO(bondhugula): handle store op's; only load's handled for now.
-void DmaGeneration::visitOperationStmt(OperationStmt *opStmt) {
+void DmaGeneration::visitOperationInst(OperationInst *opStmt) {
if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace)
return;
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index c86eec3d276..67b36cfda30 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -80,7 +80,7 @@ char LoopFusion::passID = 0;
FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
-static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt,
+static void getSingleMemRefAccess(OperationInst *loadOrStoreOpStmt,
MemRefAccess *access) {
if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
access->memref = loadOp->getMemRef();
@@ -112,8 +112,8 @@ struct FusionCandidate {
MemRefAccess dstAccess;
};
-static FusionCandidate buildFusionCandidate(OperationStmt *srcStoreOpStmt,
- OperationStmt *dstLoadOpStmt) {
+static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpStmt,
+ OperationInst *dstLoadOpStmt) {
FusionCandidate candidate;
// Get store access for src loop nest.
getSingleMemRefAccess(srcStoreOpStmt, &candidate.srcAccess);
@@ -123,7 +123,7 @@ static FusionCandidate buildFusionCandidate(OperationStmt *srcStoreOpStmt,
}
// Returns the loop depth of the loop nest surrounding 'opStmt'.
-static unsigned getLoopDepth(OperationStmt *opStmt) {
+static unsigned getLoopDepth(OperationInst *opStmt) {
unsigned loopDepth = 0;
auto *currStmt = opStmt->getParentStmt();
ForStmt *currForStmt;
@@ -141,15 +141,15 @@ namespace {
class LoopNestStateCollector : public StmtWalker<LoopNestStateCollector> {
public:
SmallVector<ForStmt *, 4> forStmts;
- SmallVector<OperationStmt *, 4> loadOpStmts;
- SmallVector<OperationStmt *, 4> storeOpStmts;
+ SmallVector<OperationInst *, 4> loadOpStmts;
+ SmallVector<OperationInst *, 4> storeOpStmts;
bool hasIfStmt = false;
void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; }
- void visitOperationStmt(OperationStmt *opStmt) {
+ void visitOperationInst(OperationInst *opStmt) {
if (opStmt->isa<LoadOp>())
loadOpStmts.push_back(opStmt);
if (opStmt->isa<StoreOp>())
@@ -171,10 +171,10 @@ public:
unsigned id;
// The top-level statment which is (or contains) loads/stores.
Statement *stmt;
- // List of load op stmts.
- SmallVector<OperationStmt *, 4> loads;
+ // List of load operations.
+ SmallVector<OperationInst *, 4> loads;
// List of store op stmts.
- SmallVector<OperationStmt *, 4> stores;
+ SmallVector<OperationInst *, 4> stores;
Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {}
// Returns the load op count for 'memref'.
@@ -312,8 +312,8 @@ public:
}
// Adds ops in 'loads' and 'stores' to node at 'id'.
- void addToNode(unsigned id, const SmallVectorImpl<OperationStmt *> &loads,
- const SmallVectorImpl<OperationStmt *> &stores) {
+ void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads,
+ const SmallVectorImpl<OperationInst *> &stores) {
Node *node = getNode(id);
for (auto *loadOpStmt : loads)
node->loads.push_back(loadOpStmt);
@@ -370,7 +370,7 @@ bool MemRefDependenceGraph::init(MLFunction *f) {
}
nodes.insert({node.id, node});
}
- if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
+ if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
// Create graph node for top-level load op.
Node node(id++, &stmt);
@@ -474,7 +474,7 @@ public:
if (!isa<ForStmt>(dstNode->stmt))
continue;
- SmallVector<OperationStmt *, 4> loads = dstNode->loads;
+ SmallVector<OperationInst *, 4> loads = dstNode->loads;
while (!loads.empty()) {
auto *dstLoadOpStmt = loads.pop_back_val();
auto *memref = dstLoadOpStmt->cast<LoadOp>()->getMemRef();
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index 183613a2f69..0a3dd65d1f4 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -120,7 +120,7 @@ PassResult LoopUnroll::runOnMLFunction(MLFunction *f) {
return hasInnerLoops;
}
- bool visitOperationStmt(OperationStmt *opStmt) { return false; }
+ bool visitOperationInst(OperationInst *opStmt) { return false; }
// FIXME: can't use base class method for this because that in turn would
// need to use the derived class method above. CRTP doesn't allow it, and
diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp
index fd23c341903..e2fd8b66e34 100644
--- a/mlir/lib/Transforms/LowerVectorTransfers.cpp
+++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp
@@ -185,7 +185,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
// case of GPUs.
llvm::SmallVector<Value *, 1> newResults = {};
if (std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value) {
- b.setInsertionPoint(cast<OperationStmt>(transfer->getOperation()));
+ b.setInsertionPoint(cast<OperationInst>(transfer->getOperation()));
auto *vector = b.create<LoadOp>(transfer->getLoc(), vecView->getResult(),
ArrayRef<Value *>{state->zero})
->getResult();
@@ -193,7 +193,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
}
// 6. Free the local buffer.
- b.setInsertionPoint(cast<OperationStmt>(transfer->getOperation()));
+ b.setInsertionPoint(cast<OperationInst>(transfer->getOperation()));
b.create<DeallocOp>(transfer->getLoc(), tmpScalarAlloc);
// 7. It is now safe to erase the statement.
@@ -207,13 +207,14 @@ public:
explicit VectorTransferExpander(MLIRContext *context)
: MLLoweringPattern(VectorTransferOpTy::getOperationName(), 1, context) {}
- PatternMatchResult match(Operation *op) const override {
+ PatternMatchResult match(OperationInst *op) const override {
if (m_Op<VectorTransferOpTy>().match(op))
return matchSuccess();
return matchFailure();
}
- void rewriteOpStmt(Operation *op, MLFuncGlobalLoweringState *funcWiseState,
+ void rewriteOpStmt(OperationInst *op,
+ MLFuncGlobalLoweringState *funcWiseState,
std::unique_ptr<PatternState> opState,
MLFuncLoweringRewriter *rewriter) const override {
rewriteAsLoops(&*op->dyn_cast<VectorTransferOpTy>(), rewriter,
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index b4d91b2506c..6f033710798 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -246,8 +246,8 @@ static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
return res;
}
-static OperationStmt *
-instantiate(FuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
+static OperationInst *
+instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap);
/// Not all Values belong to a program slice scoped within the immediately
@@ -263,7 +263,7 @@ static Value *substitute(Value *v, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap) {
auto it = substitutionsMap->find(v);
if (it == substitutionsMap->end()) {
- auto *opStmt = cast<OperationStmt>(v->getDefiningOperation());
+ auto *opStmt = cast<OperationInst>(v->getDefiningInst());
if (opStmt->isa<ConstantOp>()) {
FuncBuilder b(opStmt);
auto *inst = instantiate(&b, opStmt, hwVectorType, substitutionsMap);
@@ -272,7 +272,7 @@ static Value *substitute(Value *v, VectorType hwVectorType,
assert(res.second && "Insertion failed");
return res.first->second;
}
- v->getDefiningOperation()->emitError("Missing substitution");
+ v->getDefiningInst()->emitError("Missing substitution");
return nullptr;
}
return it->second;
@@ -384,7 +384,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType,
/// - constant splat is replaced by constant splat of `hwVectorType`.
/// TODO(ntv): add more substitutions on a per-need basis.
static SmallVector<NamedAttribute, 1>
-materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) {
+materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) {
SmallVector<NamedAttribute, 1> res;
for (auto a : opStmt->getAttrs()) {
if (auto splat = a.second.dyn_cast<SplatElementsAttr>()) {
@@ -404,8 +404,8 @@ materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) {
/// substitutionsMap.
///
/// If the underlying substitution fails, this fails too and returns nullptr.
-static OperationStmt *
-instantiate(FuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
+static OperationInst *
+instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap) {
assert(!opStmt->isa<VectorTransferReadOp>() &&
"Should call the function specialized for VectorTransferReadOp");
@@ -475,7 +475,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer,
/// `hwVectorType` int the covering of the super-vector type. For a more
/// detailed description of the problem, see the description of
/// reindexAffineIndices.
-static OperationStmt *
+static OperationInst *
instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType,
ArrayRef<unsigned> hwVectorInstance,
DenseMap<const Value *, Value *> *substitutionsMap) {
@@ -486,7 +486,7 @@ instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType,
auto cloned = b->create<VectorTransferReadOp>(
read->getLoc(), hwVectorType, read->getMemRef(), affineIndices,
projectedPermutationMap(read, hwVectorType), read->getPaddingValue());
- return cast<OperationStmt>(cloned->getOperation());
+ return cast<OperationInst>(cloned->getOperation());
}
/// Creates an instantiated version of `write` for the instance of
@@ -495,7 +495,7 @@ instantiate(FuncBuilder *b, VectorTransferReadOp *read, VectorType hwVectorType,
/// `hwVectorType` int the covering of th3e super-vector type. For a more
/// detailed description of the problem, see the description of
/// reindexAffineIndices.
-static OperationStmt *
+static OperationInst *
instantiate(FuncBuilder *b, VectorTransferWriteOp *write,
VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
DenseMap<const Value *, Value *> *substitutionsMap) {
@@ -508,7 +508,7 @@ instantiate(FuncBuilder *b, VectorTransferWriteOp *write,
substitute(write->getVector(), hwVectorType, substitutionsMap),
write->getMemRef(), affineIndices,
projectedPermutationMap(write, hwVectorType));
- return cast<OperationStmt>(cloned->getOperation());
+ return cast<OperationInst>(cloned->getOperation());
}
/// Returns `true` if stmt instance is properly cloned and inserted, false
@@ -544,7 +544,7 @@ static bool instantiateMaterialization(Statement *stmt,
// Create a builder here for unroll-and-jam effects.
FuncBuilder b(stmt);
- auto *opStmt = cast<OperationStmt>(stmt);
+ auto *opStmt = cast<OperationInst>(stmt);
if (auto write = opStmt->dyn_cast<VectorTransferWriteOp>()) {
instantiate(&b, write, state->hwVectorType, state->hwVectorInstance,
state->substitutionsMap);
@@ -620,8 +620,7 @@ static bool emitSlice(MaterializationState *state,
}
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
- LLVM_DEBUG(
- cast<OperationStmt>((*slice)[0])->getOperationFunction()->print(dbgs()));
+ LLVM_DEBUG(cast<OperationInst>((*slice)[0])->getFunction()->print(dbgs()));
// slice are topologically sorted, we can just erase them in reverse
// order. Reverse iterator does not just work simply with an operator*
@@ -652,7 +651,7 @@ static bool emitSlice(MaterializationState *state,
/// because we currently disallow vectorization of defs that come from another
/// scope.
static bool materialize(MLFunction *f,
- const SetVector<OperationStmt *> &terminators,
+ const SetVector<OperationInst *> &terminators,
MaterializationState *state) {
DenseSet<Statement *> seen;
for (auto *term : terminators) {
@@ -724,7 +723,7 @@ PassResult MaterializeVectorsPass::runOnMLFunction(MLFunction *f) {
// Capture terminators; i.e. vector_transfer_write ops involving a strict
// super-vector of subVectorType.
auto filter = [subVectorType](const Statement &stmt) {
- const auto &opStmt = cast<OperationStmt>(stmt);
+ const auto &opStmt = cast<OperationInst>(stmt);
if (!opStmt.isa<VectorTransferWriteOp>()) {
return false;
}
@@ -732,9 +731,9 @@ PassResult MaterializeVectorsPass::runOnMLFunction(MLFunction *f) {
};
auto pat = Op(filter);
auto matches = pat.match(f);
- SetVector<OperationStmt *> terminators;
+ SetVector<OperationInst *> terminators;
for (auto m : matches) {
- terminators.insert(cast<OperationStmt>(m.first));
+ terminators.insert(cast<OperationInst>(m.first));
}
auto fail = materialize(f, terminators, &state);
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index ce2fac72933..0096cd7be2d 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -64,7 +64,7 @@ FunctionPass *mlir::createPipelineDataTransferPass() {
// Returns the position of the tag memref operand given a DMA statement.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
-static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
+static unsigned getTagMemRefPos(const OperationInst &dmaStmt) {
assert(dmaStmt.isa<DmaStartOp>() || dmaStmt.isa<DmaWaitOp>());
if (dmaStmt.isa<DmaStartOp>()) {
// Second to last operand.
@@ -179,13 +179,13 @@ static bool checkTagMatch(OpPointer<DmaStartOp> startOp,
// Identify matching DMA start/finish statements to overlap computation with.
static void findMatchingStartFinishStmts(
ForStmt *forStmt,
- SmallVectorImpl<std::pair<OperationStmt *, OperationStmt *>>
+ SmallVectorImpl<std::pair<OperationInst *, OperationInst *>>
&startWaitPairs) {
// Collect outgoing DMA statements - needed to check for dependences below.
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
for (auto &stmt : *forStmt->getBody()) {
- auto *opStmt = dyn_cast<OperationStmt>(&stmt);
+ auto *opStmt = dyn_cast<OperationInst>(&stmt);
if (!opStmt)
continue;
OpPointer<DmaStartOp> dmaStartOp;
@@ -194,9 +194,9 @@ static void findMatchingStartFinishStmts(
outgoingDmaOps.push_back(dmaStartOp);
}
- SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts;
+ SmallVector<OperationInst *, 4> dmaStartStmts, dmaFinishStmts;
for (auto &stmt : *forStmt->getBody()) {
- auto *opStmt = dyn_cast<OperationStmt>(&stmt);
+ auto *opStmt = dyn_cast<OperationInst>(&stmt);
if (!opStmt)
continue;
// Collect DMA finish statements.
@@ -260,7 +260,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
return success();
}
- SmallVector<std::pair<OperationStmt *, OperationStmt *>, 4> startWaitPairs;
+ SmallVector<std::pair<OperationInst *, OperationInst *>, 4> startWaitPairs;
findMatchingStartFinishStmts(forStmt, startWaitPairs);
if (startWaitPairs.empty()) {
@@ -293,7 +293,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
// operation could have been used on it if it was dynamically shaped in
// order to create the double buffer above)
if (oldMemRef->use_empty())
- if (auto *allocStmt = oldMemRef->getDefiningStmt())
+ if (auto *allocStmt = oldMemRef->getDefiningInst())
allocStmt->erase();
}
@@ -309,7 +309,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
// If the old tag has no more uses, remove its 'dead' alloc if it was
// alloc'ed.
if (oldTagMemRef->use_empty())
- if (auto *allocStmt = oldTagMemRef->getDefiningStmt())
+ if (auto *allocStmt = oldTagMemRef->getDefiningInst())
allocStmt->erase();
}
@@ -329,7 +329,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
} else {
// If a slice wasn't created, the reachable affine_apply op's from its
// operands are the ones that go with it.
- SmallVector<OperationStmt *, 4> affineApplyStmts;
+ SmallVector<OperationInst *, 4> affineApplyStmts;
SmallVector<Value *, 4> operands(dmaStartStmt->getOperands());
getReachableAffineApplyOps(operands, affineApplyStmts);
for (const auto *stmt : affineApplyStmts) {
@@ -352,7 +352,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
shifts[s++] = stmtShiftMap[&stmt];
LLVM_DEBUG(
// Tagging statements with shifts for debugging purposes.
- if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
+ if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
FuncBuilder b(opStmt);
opStmt->setAttr(b.getIdentifier("shift"),
b.getI64IntegerAttr(shifts[s - 1]));
diff --git a/mlir/lib/Transforms/SimplifyAffineExpr.cpp b/mlir/lib/Transforms/SimplifyAffineExpr.cpp
index 048e26ae115..b0b31e01175 100644
--- a/mlir/lib/Transforms/SimplifyAffineExpr.cpp
+++ b/mlir/lib/Transforms/SimplifyAffineExpr.cpp
@@ -47,7 +47,7 @@ struct SimplifyAffineStructures : public FunctionPass,
PassResult runOnCFGFunction(CFGFunction *f) override { return success(); }
void visitIfStmt(IfStmt *ifStmt);
- void visitOperationStmt(OperationStmt *opStmt);
+ void visitOperationInst(OperationInst *opStmt);
static char passID;
};
@@ -75,7 +75,7 @@ void SimplifyAffineStructures::visitIfStmt(IfStmt *ifStmt) {
ifStmt->setIntegerSet(simplifyIntegerSet(set));
}
-void SimplifyAffineStructures::visitOperationStmt(OperationStmt *opStmt) {
+void SimplifyAffineStructures::visitOperationInst(OperationInst *opStmt) {
for (auto attr : opStmt->getAttrs()) {
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
MutableAffineMap mMap(mapAttr.getValue());
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index a690844f7a6..f493e4b090b 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -39,7 +39,7 @@ public:
void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter);
- void addToWorklist(Operation *op) {
+ void addToWorklist(OperationInst *op) {
// Check to see if the worklist already contains this op.
if (worklistMap.count(op))
return;
@@ -48,7 +48,7 @@ public:
worklist.push_back(op);
}
- Operation *popFromWorklist() {
+ OperationInst *popFromWorklist() {
auto *op = worklist.back();
worklist.pop_back();
@@ -60,7 +60,7 @@ public:
/// If the specified operation is in the worklist, remove it. If not, this is
/// a no-op.
- void removeFromWorklist(Operation *op) {
+ void removeFromWorklist(OperationInst *op) {
auto it = worklistMap.find(op);
if (it != worklistMap.end()) {
assert(worklist[it->second] == op && "malformed worklist data structure");
@@ -76,13 +76,13 @@ private:
/// need to be revisited, plus their index in the worklist. This allows us to
/// efficiently remove operations from the worklist when they are removed even
/// if they aren't the root of a pattern.
- std::vector<Operation *> worklist;
- DenseMap<Operation *, unsigned> worklistMap;
+ std::vector<OperationInst *> worklist;
+ DenseMap<OperationInst *, unsigned> worklistMap;
/// As part of canonicalization, we move constants to the top of the entry
/// block of the current function and de-duplicate them. This keeps track of
/// constants we have done this for.
- DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants;
+ DenseMap<std::pair<Attribute, Type>, OperationInst *> uniquedConstants;
};
}; // end anonymous namespace
@@ -94,22 +94,22 @@ public:
WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context)
: PatternRewriter(context), driver(driver) {}
- virtual void setInsertionPoint(Operation *op) = 0;
+ virtual void setInsertionPoint(OperationInst *op) = 0;
// If an operation is about to be removed, make sure it is not in our
// worklist anymore because we'd get dangling references to it.
- void notifyOperationRemoved(Operation *op) override {
+ void notifyOperationRemoved(OperationInst *op) override {
driver.removeFromWorklist(op);
}
// When the root of a pattern is about to be replaced, it can trigger
// simplifications to its users - make sure to add them to the worklist
// before the root is changed.
- void notifyRootReplaced(Operation *op) override {
+ void notifyRootReplaced(OperationInst *op) override {
for (auto *result : op->getResults())
// TODO: Add a result->getUsers() iterator.
for (auto &user : result->getUses()) {
- if (auto *op = dyn_cast<Operation>(user.getOwner()))
+ if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
driver.addToWorklist(op);
}
@@ -168,7 +168,6 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
// canonical version. To ensure safe dominance, move the operation to the
// top of the function.
entry = op;
-
auto &entryBB = currentFunction->front();
op->moveBefore(&entryBB, entryBB.begin());
continue;
@@ -186,7 +185,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
operandConstants.clear();
for (auto *operand : op->getOperands()) {
Attribute operandCst;
- if (auto *operandOp = operand->getDefiningOperation()) {
+ if (auto *operandOp = operand->getDefiningInst()) {
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
operandCst = operandConstantOp->getValue();
}
@@ -219,7 +218,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
//
// TODO: Add a result->getUsers() iterator.
for (auto &operand : op->getResult(i)->getUses()) {
- if (auto *op = dyn_cast<Operation>(operand.getOwner()))
+ if (auto *op = dyn_cast<OperationInst>(operand.getOwner()))
addToWorklist(op);
}
@@ -265,15 +264,15 @@ static void processMLFunction(MLFunction *fn,
// Implement the hook for creating operations, and make sure that newly
// created ops are added to the worklist for processing.
- Operation *createOperation(const OperationState &state) override {
+ OperationInst *createOperation(const OperationState &state) override {
auto *result = builder.createOperation(state);
driver.addToWorklist(result);
return result;
}
- void setInsertionPoint(Operation *op) override {
+ void setInsertionPoint(OperationInst *op) override {
// Any new operations should be added before this statement.
- builder.setInsertionPoint(cast<OperationStmt>(op));
+ builder.setInsertionPoint(cast<OperationInst>(op));
}
private:
@@ -281,7 +280,7 @@ static void processMLFunction(MLFunction *fn,
};
GreedyPatternRewriteDriver driver(std::move(patterns));
- fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); });
+ fn->walk([&](OperationInst *stmt) { driver.addToWorklist(stmt); });
FuncBuilder mlBuilder(fn);
MLFuncRewriter rewriter(driver, mlBuilder);
@@ -297,13 +296,13 @@ static void processCFGFunction(CFGFunction *fn,
// Implement the hook for creating operations, and make sure that newly
// created ops are added to the worklist for processing.
- Operation *createOperation(const OperationState &state) override {
+ OperationInst *createOperation(const OperationState &state) override {
auto *result = builder.createOperation(state);
driver.addToWorklist(result);
return result;
}
- void setInsertionPoint(Operation *op) override {
+ void setInsertionPoint(OperationInst *op) override {
// Any new operations should be added before this instruction.
builder.setInsertionPoint(cast<OperationInst>(op));
}
@@ -315,7 +314,7 @@ static void processCFGFunction(CFGFunction *fn,
GreedyPatternRewriteDriver driver(std::move(patterns));
for (auto &bb : *fn)
for (auto &op : bb)
- if (auto *opInst = dyn_cast<OperationStmt>(&op))
+ if (auto *opInst = dyn_cast<OperationInst>(&op))
driver.addToWorklist(opInst);
FuncBuilder cfgBuilder(fn);
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index b92e15d7857..4a2831c0a83 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -156,7 +156,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
auto *loopChunk = b->createFor(srcForStmt->getLoc(), lbOperands, lbMap,
ubOperands, ubMap, srcForStmt->getStep());
- OperationStmt::OperandMapTy operandMap;
+ OperationInst::OperandMapTy operandMap;
for (auto it = stmtGroupQueue.begin() + offset, e = stmtGroupQueue.end();
it != e; ++it) {
diff --git a/mlir/lib/Transforms/Utils/LoweringUtils.cpp b/mlir/lib/Transforms/Utils/LoweringUtils.cpp
index 90f4d0c028d..6fca54a9972 100644
--- a/mlir/lib/Transforms/Utils/LoweringUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoweringUtils.cpp
@@ -124,7 +124,7 @@ bool mlir::expandAffineApply(AffineApplyOp *op) {
if (!op)
return true;
- FuncBuilder builder(cast<OperationStmt>(op->getOperation()));
+ FuncBuilder builder(cast<OperationInst>(op->getOperation()));
auto affineMap = op->getAffineMap();
for (auto numberedExpr : llvm::enumerate(affineMap.getResults())) {
Value *expanded = expandAffineExpr(&builder, numberedExpr.value(), op);
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index 2bc1be1b785..c8317c27f74 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -36,7 +36,7 @@ using namespace mlir;
/// Return true if this operation dereferences one or more memref's.
// Temporary utility: will be replaced when this is modeled through
// side-effects/op traits. TODO(b/117228571)
-static bool isMemRefDereferencingOp(const Operation &op) {
+static bool isMemRefDereferencingOp(const OperationInst &op) {
if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
op.isa<DmaWaitOp>())
return true;
@@ -82,10 +82,10 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
newMemRef->getType().cast<MemRefType>().getElementType());
- // Walk all uses of old memref. Statement using the memref gets replaced.
+ // Walk all uses of old memref. Operation using the memref gets replaced.
for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
StmtOperand &use = *(it++);
- auto *opStmt = cast<OperationStmt>(use.getOwner());
+ auto *opStmt = cast<OperationInst>(use.getOwner());
// Skip this use if it's not dominated by domStmtFilter.
if (domStmtFilter && !dominates(*domStmtFilter, *opStmt))
@@ -124,7 +124,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
// TODO(mlir-team): An operation/SSA value should provide a method to
// return the position of an SSA result in its defining
// operation.
- assert(extraIndex->getDefiningStmt()->getNumResults() == 1 &&
+ assert(extraIndex->getDefiningInst()->getNumResults() == 1 &&
"single result op's expected to generate these indices");
assert((extraIndex->isValidDim() || extraIndex->isValidSymbol()) &&
"invalid memory op index");
@@ -186,10 +186,10 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
// operands were drawing results from multiple affine apply ops, this also leads
// to a collapse into a single affine apply op. The final results of the
// composed AffineApplyOp are returned in output parameter 'results'.
-OperationStmt *
+OperationInst *
mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
ArrayRef<Value *> operands,
- ArrayRef<OperationStmt *> affineApplyOps,
+ ArrayRef<OperationInst *> affineApplyOps,
SmallVectorImpl<Value *> *results) {
// Create identity map with same number of dimensions as number of operands.
auto map = builder->getMultiDimIdentityMap(operands.size());
@@ -216,7 +216,7 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
(*results)[i] = affineApplyOp->getResult(i);
}
- return cast<OperationStmt>(affineApplyOp->getOperation());
+ return cast<OperationInst>(affineApplyOp->getOperation());
}
/// Given an operation statement, inserts a new single affine apply operation,
@@ -247,19 +247,19 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
/// all the affine_apply op's supplying operands to this opStmt do not have any
/// uses besides this opStmt. Returns the new affine_apply operation statement
/// otherwise.
-OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
+OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) {
// Collect all operands that are results of affine apply ops.
SmallVector<Value *, 4> subOperands;
subOperands.reserve(opStmt->getNumOperands());
for (auto *operand : opStmt->getOperands()) {
- auto *defStmt = operand->getDefiningStmt();
+ auto *defStmt = operand->getDefiningInst();
if (defStmt && defStmt->isa<AffineApplyOp>()) {
subOperands.push_back(operand);
}
}
// Gather sequence of AffineApplyOps reachable from 'subOperands'.
- SmallVector<OperationStmt *, 4> affineApplyOps;
+ SmallVector<OperationInst *, 4> affineApplyOps;
getReachableAffineApplyOps(subOperands, affineApplyOps);
// Skip transforming if there are no affine maps to compose.
if (affineApplyOps.empty())
@@ -313,11 +313,11 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
}
void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
- if (!affineApplyOp->getOperation()->getOperationFunction()->isML()) {
+ if (!affineApplyOp->getOperation()->getFunction()->isML()) {
// TODO: Support forward substitution for CFG style functions.
return;
}
- auto *opStmt = cast<OperationStmt>(affineApplyOp->getOperation());
+ auto *opStmt = cast<OperationInst>(affineApplyOp->getOperation());
// Iterate through all uses of all results of 'opStmt', forward substituting
// into any uses which are AffineApplyOps.
for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e;
@@ -326,7 +326,7 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
for (auto it = result->use_begin(); it != result->use_end();) {
StmtOperand &use = *(it++);
auto *useStmt = use.getOwner();
- auto *useOpStmt = dyn_cast<OperationStmt>(useStmt);
+ auto *useOpStmt = dyn_cast<OperationInst>(useStmt);
// Skip if use is not AffineApplyOp.
if (useOpStmt == nullptr || !useOpStmt->isa<AffineApplyOp>())
continue;
@@ -379,7 +379,7 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
: forStmt->getUpperBoundOperands();
for (const auto *operand : boundOperands) {
Attribute operandCst;
- if (auto *operandOp = operand->getDefiningOperation()) {
+ if (auto *operandOp = operand->getDefiningInst()) {
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
operandCst = operandConstantOp->getValue();
}
@@ -415,7 +415,8 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
}
void mlir::remapFunctionAttrs(
- Operation &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
+ OperationInst &op,
+ const DenseMap<Attribute, FunctionAttr> &remappingTable) {
for (auto attr : op.getAttrs()) {
// Do the remapping, if we got the same thing back, then it must contain
// functions that aren't getting remapped.
@@ -451,7 +452,7 @@ void mlir::remapFunctionAttrs(
struct MLFnWalker : public StmtWalker<MLFnWalker> {
MLFnWalker(const DenseMap<Attribute, FunctionAttr> &remappingTable)
: remappingTable(remappingTable) {}
- void visitOperationStmt(OperationStmt *opStmt) {
+ void visitOperationInst(OperationInst *opStmt) {
remapFunctionAttrs(*opStmt, remappingTable);
}
diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
index b8145126770..5abd3a3cfcc 100644
--- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
+++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
@@ -98,7 +98,7 @@ void VectorizerTestPass::testVectorShapeRatio(MLFunction *f) {
// Only filter statements that operate on a strict super-vector and have one
// return. This makes testing easier.
auto filter = [subVectorType](const Statement &stmt) {
- auto *opStmt = dyn_cast<OperationStmt>(&stmt);
+ auto *opStmt = dyn_cast<OperationInst>(&stmt);
if (!opStmt) {
return false;
}
@@ -116,7 +116,7 @@ void VectorizerTestPass::testVectorShapeRatio(MLFunction *f) {
auto pat = Op(filter);
auto matches = pat.match(f);
for (auto m : matches) {
- auto *opStmt = cast<OperationStmt>(m.first);
+ auto *opStmt = cast<OperationInst>(m.first);
// This is a unit test that only checks and prints shape ratio.
// As a consequence we write only Ops with a single return type for the
// purpose of this test. If we need to test more intricate behavior in the
@@ -146,7 +146,7 @@ static MLFunctionMatches matchTestSlicingOps(MLFunction *f) {
using matcher::Op;
// Match all OpStatements with the kTestSlicingOpName name.
auto filter = [](const Statement &stmt) {
- const auto &opStmt = cast<OperationStmt>(stmt);
+ const auto &opStmt = cast<OperationInst>(stmt);
return opStmt.getName().getStringRef() == kTestSlicingOpName;
};
auto pat = Op(filter);
@@ -192,7 +192,7 @@ void VectorizerTestPass::testSlicing(MLFunction *f) {
}
bool customOpWithAffineMapAttribute(const Statement &stmt) {
- const auto &opStmt = cast<OperationStmt>(stmt);
+ const auto &opStmt = cast<OperationInst>(stmt);
return opStmt.getName().getStringRef() ==
VectorizerTestPass::kTestAffineMapOpName;
}
@@ -205,7 +205,7 @@ void VectorizerTestPass::testComposeMaps(MLFunction *f) {
maps.reserve(matches.size());
std::reverse(matches.begin(), matches.end());
for (auto m : matches) {
- auto *opStmt = cast<OperationStmt>(m.first);
+ auto *opStmt = cast<OperationInst>(m.first);
auto map = opStmt->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
.cast<AffineMapAttr>()
.getValue();
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index 80d16475e47..0efe727f5b4 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -722,22 +722,22 @@ namespace {
struct VectorizationState {
/// Adds an entry of pre/post vectorization statements in the state.
- void registerReplacement(OperationStmt *key, OperationStmt *value);
+ void registerReplacement(OperationInst *key, OperationInst *value);
/// When the current vectorization pattern is successful, this erases the
/// instructions that were marked for erasure in the proper order and resets
/// the internal state for the next pattern.
void finishVectorizationPattern();
- // In-order tracking of original OperationStmt that have been vectorized.
+ // In-order tracking of original OperationInst that have been vectorized.
// Erase in reverse order.
- SmallVector<OperationStmt *, 16> toErase;
- // Set of OperationStmt that have been vectorized (the values in the
+ SmallVector<OperationInst *, 16> toErase;
+ // Set of OperationInst that have been vectorized (the values in the
// vectorizationMap for hashed access). The vectorizedSet is used in
// particular to filter the statements that have already been vectorized by
// this pattern, when iterating over nested loops in this pattern.
- DenseSet<OperationStmt *> vectorizedSet;
- // Map of old scalar OperationStmt to new vectorized OperationStmt.
- DenseMap<OperationStmt *, OperationStmt *> vectorizationMap;
+ DenseSet<OperationInst *> vectorizedSet;
+ // Map of old scalar OperationInst to new vectorized OperationInst.
+ DenseMap<OperationInst *, OperationInst *> vectorizationMap;
// Map of old scalar Value to new vectorized Value.
DenseMap<const Value *, Value *> replacementMap;
// The strategy drives which loop to vectorize by which amount.
@@ -746,17 +746,17 @@ struct VectorizationState {
// vectorizeOperations function. They consist of the subset of load operations
// that have been vectorized. They can be retrieved from `vectorizationMap`
// but it is convenient to keep track of them in a separate data structure.
- DenseSet<OperationStmt *> roots;
+ DenseSet<OperationInst *> roots;
// Terminator statements for the worklist in the vectorizeOperations function.
// They consist of the subset of store operations that have been vectorized.
// They can be retrieved from `vectorizationMap` but it is convenient to keep
// track of them in a separate data structure. Since they do not necessarily
// belong to use-def chains starting from loads (e.g storing a constant), we
// need to handle them in a post-pass.
- DenseSet<OperationStmt *> terminators;
+ DenseSet<OperationInst *> terminators;
// Checks that the type of `stmt` is StoreOp and adds it to the terminators
// set.
- void registerTerminator(OperationStmt *stmt);
+ void registerTerminator(OperationInst *stmt);
private:
void registerReplacement(const Value *key, Value *value);
@@ -764,8 +764,8 @@ private:
} // end namespace
-void VectorizationState::registerReplacement(OperationStmt *key,
- OperationStmt *value) {
+void VectorizationState::registerReplacement(OperationInst *key,
+ OperationInst *value) {
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op: ");
LLVM_DEBUG(key->print(dbgs()));
LLVM_DEBUG(dbgs() << " into ");
@@ -784,7 +784,7 @@ void VectorizationState::registerReplacement(OperationStmt *key,
}
}
-void VectorizationState::registerTerminator(OperationStmt *stmt) {
+void VectorizationState::registerTerminator(OperationInst *stmt) {
assert(stmt->isa<StoreOp>() && "terminator must be a StoreOp");
assert(terminators.count(stmt) == 0 &&
"terminator was already inserted previously");
@@ -832,7 +832,7 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType);
// Materialize a MemRef with 1 vector.
- auto *opStmt = cast<OperationStmt>(memoryOp->getOperation());
+ auto *opStmt = cast<OperationInst>(memoryOp->getOperation());
// For now, vector_transfers must be aligned, operate only on indices with an
// identity subset of AffineMap and do not change layout.
// TODO(ntv): increase the expressiveness power of vector_transfer operations
@@ -847,7 +847,7 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
opStmt->getLoc(), vectorType, memoryOp->getMemRef(),
map(makePtrDynCaster<Value>(), memoryOp->getIndices()), permutationMap);
state->registerReplacement(opStmt,
- cast<OperationStmt>(transfer->getOperation()));
+ cast<OperationInst>(transfer->getOperation()));
} else {
state->registerTerminator(opStmt);
}
@@ -866,7 +866,7 @@ static bool vectorizeForStmt(ForStmt *loop, int64_t step,
if (!matcher::isLoadOrStore(stmt)) {
return false;
}
- auto *opStmt = cast<OperationStmt>(&stmt);
+ auto *opStmt = cast<OperationInst>(&stmt);
return state->vectorizationMap.count(opStmt) == 0 &&
state->vectorizedSet.count(opStmt) == 0 &&
state->roots.count(opStmt) == 0 &&
@@ -875,7 +875,7 @@ static bool vectorizeForStmt(ForStmt *loop, int64_t step,
auto loadAndStores = matcher::Op(notVectorizedThisPattern);
auto matches = loadAndStores.match(loop);
for (auto ls : matches) {
- auto *opStmt = cast<OperationStmt>(ls.first);
+ auto *opStmt = cast<OperationInst>(ls.first);
auto load = opStmt->dyn_cast<LoadOp>();
auto store = opStmt->dyn_cast<StoreOp>();
LLVM_DEBUG(opStmt->print(dbgs()));
@@ -974,14 +974,14 @@ static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
Location loc = stmt->getLoc();
auto vectorType = type.cast<VectorType>();
auto attr = SplatElementsAttr::get(vectorType, constant.getValue());
- auto *constantOpStmt = cast<OperationStmt>(constant.getOperation());
+ auto *constantOpStmt = cast<OperationInst>(constant.getOperation());
OperationState state(
b.getContext(), loc, constantOpStmt->getName().getStringRef(), {},
{vectorType},
{make_pair(Identifier::get("value", b.getContext()), attr)});
- auto *splat = cast<OperationStmt>(b.createOperation(state));
+ auto *splat = cast<OperationInst>(b.createOperation(state));
return splat->getResult(0);
}
@@ -994,7 +994,7 @@ static Type getVectorType(Value *v, const VectorizationState &state) {
if (!VectorType::isValidElementType(v->getType())) {
return Type();
}
- auto *definingOpStmt = cast<OperationStmt>(v->getDefiningStmt());
+ auto *definingOpStmt = cast<OperationInst>(v->getDefiningInst());
if (state.vectorizedSet.count(definingOpStmt) > 0) {
return v->getType().cast<VectorType>();
}
@@ -1026,7 +1026,7 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt,
VectorizationState *state) {
LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: ");
LLVM_DEBUG(operand->print(dbgs()));
- auto *definingStatement = cast<OperationStmt>(operand->getDefiningStmt());
+ auto *definingStatement = cast<OperationInst>(operand->getDefiningInst());
// 1. If this value has already been vectorized this round, we are done.
if (state->vectorizedSet.count(definingStatement) > 0) {
LLVM_DEBUG(dbgs() << " -> already vector operand");
@@ -1049,7 +1049,7 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt,
return nullptr;
}
// 3. vectorize constant.
- if (auto constant = operand->getDefiningStmt()->dyn_cast<ConstantOp>()) {
+ if (auto constant = operand->getDefiningInst()->dyn_cast<ConstantOp>()) {
return vectorizeConstant(stmt, *constant,
getVectorType(operand, *state).cast<VectorType>());
}
@@ -1059,17 +1059,17 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt,
return nullptr;
};
-/// Encodes OperationStmt-specific behavior for vectorization. In general we
+/// Encodes OperationInst-specific behavior for vectorization. In general we
/// assume that all operands of an op must be vectorized but this is not always
/// true. In the future, it would be nice to have a trait that describes how a
/// particular operation vectorizes. For now we implement the case distinction
/// here.
-/// Returns a vectorized form of stmt or nullptr if vectorization fails.
+/// Returns a vectorized form of an operation or nullptr if vectorization fails.
/// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized.
/// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
/// do one-off logic here; ideally it would be TableGen'd.
-static OperationStmt *vectorizeOneOperationStmt(FuncBuilder *b,
- OperationStmt *opStmt,
+static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
+ OperationInst *opStmt,
VectorizationState *state) {
// Sanity checks.
assert(!opStmt->isa<LoadOp>() &&
@@ -1091,7 +1091,7 @@ static OperationStmt *vectorizeOneOperationStmt(FuncBuilder *b,
LLVM_DEBUG(permutationMap.print(dbgs()));
auto transfer = b.create<VectorTransferWriteOp>(
opStmt->getLoc(), vectorValue, memRef, indices, permutationMap);
- auto *res = cast<OperationStmt>(transfer->getOperation());
+ auto *res = cast<OperationInst>(transfer->getOperation());
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
// "Terminators" (i.e. StoreOps) are erased on the spot.
opStmt->erase();
@@ -1114,8 +1114,8 @@ static OperationStmt *vectorizeOneOperationStmt(FuncBuilder *b,
// Create a clone of the op with the proper operands and return types.
// TODO(ntv): The following assumes there is always an op with a fixed
// name that works both in scalar mode and vector mode.
- // TODO(ntv): Is it worth considering an OperationStmt.clone operation
- // which changes the type so we can promote an OperationStmt with less
+ // TODO(ntv): Is it worth considering an OperationInst.clone operation
+ // which changes the type so we can promote an OperationInst with less
// boilerplate?
OperationState newOp(b->getContext(), opStmt->getLoc(),
opStmt->getName().getStringRef(), operands, types,
@@ -1123,22 +1123,22 @@ static OperationStmt *vectorizeOneOperationStmt(FuncBuilder *b,
return b->createOperation(newOp);
}
-/// Iterates over the OperationStmt in the loop and rewrites them using their
+/// Iterates over the OperationInst in the loop and rewrites them using their
/// vectorized counterpart by:
-/// 1. iteratively building a worklist of uses of the OperationStmt vectorized
+/// 1. iteratively building a worklist of uses of the OperationInst vectorized
/// so far by this pattern;
-/// 2. for each OperationStmt in the worklist, create the vector form of this
+/// 2. for each OperationInst in the worklist, create the vector form of this
/// operation and replace all its uses by the vectorized form. For this step,
/// the worklist must be traversed in order;
/// 3. verify that all operands of the newly vectorized operation have been
/// vectorized by this pattern.
static bool vectorizeOperations(VectorizationState *state) {
// 1. create initial worklist with the uses of the roots.
- SetVector<OperationStmt *> worklist;
- auto insertUsesOf = [&worklist, state](Operation *vectorized) {
- for (auto *r : cast<OperationStmt>(vectorized)->getResults())
+ SetVector<OperationInst *> worklist;
+ auto insertUsesOf = [&worklist, state](OperationInst *vectorized) {
+ for (auto *r : vectorized->getResults())
for (auto &u : r->getUses()) {
- auto *stmt = cast<OperationStmt>(u.getOwner());
+ auto *stmt = cast<OperationInst>(u.getOwner());
// Don't propagate to terminals, a separate pass is needed for those.
// TODO(ntv)[b/119759136]: use isa<> once Op is implemented.
if (state->terminators.count(stmt) > 0) {
@@ -1160,7 +1160,7 @@ static bool vectorizeOperations(VectorizationState *state) {
// 2. Create vectorized form of the statement.
// Insert it just before stmt, on success register stmt as replaced.
FuncBuilder b(stmt);
- auto *vectorizedStmt = vectorizeOneOperationStmt(&b, stmt, state);
+ auto *vectorizedStmt = vectorizeOneOperationInst(&b, stmt, state);
if (!vectorizedStmt) {
return true;
}
@@ -1169,11 +1169,11 @@ static bool vectorizeOperations(VectorizationState *state) {
// Note that we cannot just call replaceAllUsesWith because it may
// result in ops with mixed types, for ops whose operands have not all
// yet been vectorized. This would be invalid IR.
- state->registerReplacement(cast<OperationStmt>(stmt), vectorizedStmt);
+ state->registerReplacement(stmt, vectorizedStmt);
// 4. Augment the worklist with uses of the statement we just vectorized.
// This preserves the proper order in the worklist.
- apply(insertUsesOf, ArrayRef<Operation *>{stmt});
+ apply(insertUsesOf, ArrayRef<OperationInst *>{stmt});
}
return false;
}
@@ -1223,12 +1223,12 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
// Form the root operationsthat have been set in the replacementMap.
// For now, these roots are the loads for which vector_transfer_read
// operations have been inserted.
- auto getDefiningOperation = [](const Value *val) {
- return const_cast<Value *>(val)->getDefiningOperation();
+ auto getDefiningInst = [](const Value *val) {
+ return const_cast<Value *>(val)->getDefiningInst();
};
using ReferenceTy = decltype(*(state.replacementMap.begin()));
auto getKey = [](ReferenceTy it) { return it.first; };
- auto roots = map(getDefiningOperation, map(getKey, state.replacementMap));
+ auto roots = map(getDefiningInst, map(getKey, state.replacementMap));
// Vectorize the root operations and everything reached by use-def chains
// except the terminators (store statements) that need to be post-processed
@@ -1240,12 +1240,12 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
}
// Finally, vectorize the terminators. If anything fails to vectorize, skip.
- auto vectorizeOrFail = [&fail, &state](OperationStmt *stmt) {
+ auto vectorizeOrFail = [&fail, &state](OperationInst *stmt) {
if (fail) {
return;
}
FuncBuilder b(stmt);
- auto *res = vectorizeOneOperationStmt(&b, stmt, &state);
+ auto *res = vectorizeOneOperationInst(&b, stmt, &state);
if (res == nullptr) {
fail = true;
}
OpenPOWER on IntegriCloud