summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/AffineOps/AffineOps.cpp151
-rw-r--r--mlir/lib/AffineOps/DialectRegistration.cpp22
-rw-r--r--mlir/lib/Analysis/LoopAnalysis.cpp11
-rw-r--r--mlir/lib/Analysis/NestedMatcher.cpp20
-rw-r--r--mlir/lib/Analysis/Utils.cpp22
-rw-r--r--mlir/lib/Analysis/Verifier.cpp25
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp38
-rw-r--r--mlir/lib/IR/Builders.cpp7
-rw-r--r--mlir/lib/IR/Instruction.cpp109
-rw-r--r--mlir/lib/IR/Operation.cpp2
-rw-r--r--mlir/lib/IR/Value.cpp2
-rw-r--r--mlir/lib/Parser/Parser.cpp129
-rw-r--r--mlir/lib/Parser/TokenKinds.def2
-rw-r--r--mlir/lib/Transforms/CSE.cpp10
-rw-r--r--mlir/lib/Transforms/LoopFusion.cpp24
-rw-r--r--mlir/lib/Transforms/LoopUnroll.cpp9
-rw-r--r--mlir/lib/Transforms/LowerAffine.cpp59
-rw-r--r--mlir/lib/Transforms/MaterializeVectors.cpp7
-rw-r--r--mlir/lib/Transforms/SimplifyAffineStructures.cpp37
19 files changed, 344 insertions, 342 deletions
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp
new file mode 100644
index 00000000000..5b29467fc44
--- /dev/null
+++ b/mlir/lib/AffineOps/AffineOps.cpp
@@ -0,0 +1,151 @@
+//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/OpImplementation.h"
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// AffineOpsDialect
+//===----------------------------------------------------------------------===//
+
+AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
+ : Dialect(/*namePrefix=*/"", context) {
+ addOperations<AffineIfOp>();
+}
+
+//===----------------------------------------------------------------------===//
+// AffineIfOp
+//===----------------------------------------------------------------------===//
+
+void AffineIfOp::build(Builder *builder, OperationState *result,
+ IntegerSet condition,
+ ArrayRef<Value *> conditionOperands) {
+ result->addAttribute(getConditionAttrName(), IntegerSetAttr::get(condition));
+ result->addOperands(conditionOperands);
+
+ // Reserve 2 block lists, one for the 'then' and one for the 'else' regions.
+ result->reserveBlockLists(2);
+}
+
+bool AffineIfOp::verify() const {
+ // Verify that we have a condition attribute.
+ auto conditionAttr = getAttrOfType<IntegerSetAttr>(getConditionAttrName());
+ if (!conditionAttr)
+ return emitOpError("requires an integer set attribute named 'condition'");
+
+ // Verify that the operands are valid dimension/symbols.
+ IntegerSet condition = conditionAttr.getValue();
+ for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
+ const Value *operand = getOperand(i);
+ if (i < condition.getNumDims() && !operand->isValidDim())
+ return emitOpError("operand cannot be used as a dimension id");
+ if (i >= condition.getNumDims() && !operand->isValidSymbol())
+ return emitOpError("operand cannot be used as a symbol");
+ }
+
+ // Verify that the entry of each child blocklist does not have arguments.
+ for (const auto &blockList : getInstruction()->getBlockLists()) {
+ if (blockList.empty())
+ continue;
+
+ // TODO(riverriddle) We currently do not allow multiple blocks in child
+ // block lists.
+ if (std::next(blockList.begin()) != blockList.end())
+ return emitOpError(
+ "expects only one block per 'if' or 'else' block list");
+ if (blockList.front().getTerminator())
+ return emitOpError("expects region block to not have a terminator");
+
+ for (const auto &b : blockList)
+ if (b.getNumArguments() != 0)
+ return emitOpError(
+ "requires that child entry blocks have no arguments");
+ }
+ return false;
+}
+
+bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) {
+ // Parse the condition attribute set.
+ IntegerSetAttr conditionAttr;
+ unsigned numDims;
+ if (parser->parseAttribute(conditionAttr, getConditionAttrName().data(),
+ result->attributes) ||
+ parseDimAndSymbolList(parser, result->operands, numDims))
+ return true;
+
+ // Verify the condition operands.
+ auto set = conditionAttr.getValue();
+ if (set.getNumDims() != numDims)
+ return parser->emitError(
+ parser->getNameLoc(),
+ "dim operand count and integer set dim count must match");
+ if (numDims + set.getNumSymbols() != result->operands.size())
+ return parser->emitError(
+ parser->getNameLoc(),
+ "symbol operand count and integer set symbol count must match");
+
+ // Parse the 'then' block list.
+ if (parser->parseBlockList())
+ return true;
+
+ // If we find an 'else' keyword then parse the else block list.
+ if (!parser->parseOptionalKeyword("else")) {
+ if (parser->parseBlockList())
+ return true;
+ }
+
+ // Reserve 2 block lists, one for the 'then' and one for the 'else' regions.
+ result->reserveBlockLists(2);
+ return false;
+}
+
+void AffineIfOp::print(OpAsmPrinter *p) const {
+ auto conditionAttr = getAttrOfType<IntegerSetAttr>(getConditionAttrName());
+ *p << "if " << conditionAttr;
+ printDimAndSymbolList(operand_begin(), operand_end(),
+ conditionAttr.getValue().getNumDims(), p);
+ p->printBlockList(getInstruction()->getBlockList(0));
+
+ // Print the 'else' block list if it has any blocks.
+ const auto &elseBlockList = getInstruction()->getBlockList(1);
+ if (!elseBlockList.empty()) {
+ *p << " else";
+ p->printBlockList(elseBlockList);
+ }
+}
+
+IntegerSet AffineIfOp::getIntegerSet() const {
+ return getAttrOfType<IntegerSetAttr>(getConditionAttrName()).getValue();
+}
+void AffineIfOp::setIntegerSet(IntegerSet newSet) {
+ setAttr(
+ Identifier::get(getConditionAttrName(), getInstruction()->getContext()),
+ IntegerSetAttr::get(newSet));
+}
+
+/// Returns the list of 'then' blocks.
+BlockList &AffineIfOp::getThenBlocks() {
+ return getInstruction()->getBlockList(0);
+}
+
+/// Returns the list of 'else' blocks.
+BlockList &AffineIfOp::getElseBlocks() {
+ return getInstruction()->getBlockList(1);
+}
diff --git a/mlir/lib/AffineOps/DialectRegistration.cpp b/mlir/lib/AffineOps/DialectRegistration.cpp
new file mode 100644
index 00000000000..0afb32c1bd6
--- /dev/null
+++ b/mlir/lib/AffineOps/DialectRegistration.cpp
@@ -0,0 +1,22 @@
+//===- DialectRegistration.cpp - Register Affine Op dialect ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/AffineOps/AffineOps.h"
+using namespace mlir;
+
+// Static initialization for Affine op dialect registration.
+static DialectRegistration<AffineOpsDialect> StandardOps;
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index 219f356807a..07c903a6613 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -21,6 +21,7 @@
#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/NestedMatcher.h"
@@ -246,6 +247,16 @@ static bool isVectorizableLoopWithCond(const ForInst &loop,
return false;
}
+ // No vectorization across unknown regions.
+ auto regions = matcher::Op([](const Instruction &inst) -> bool {
+ auto &opInst = cast<OperationInst>(inst);
+ return opInst.getNumBlockLists() != 0 && !opInst.isa<AffineIfOp>();
+ });
+ auto regionsMatched = regions.match(forInst);
+ if (!regionsMatched.empty()) {
+ return false;
+ }
+
auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite);
auto vectorTransfersMatched = vectorTransfers.match(forInst);
if (!vectorTransfersMatched.empty()) {
diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp
index 4f32e9b22f4..491a9bef1b9 100644
--- a/mlir/lib/Analysis/NestedMatcher.cpp
+++ b/mlir/lib/Analysis/NestedMatcher.cpp
@@ -16,6 +16,7 @@
// =============================================================================
#include "mlir/Analysis/NestedMatcher.h"
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/ArrayRef.h"
@@ -186,6 +187,11 @@ FilterFunctionType NestedPattern::getFilterFunction() {
return storage->filter;
}
+static bool isAffineIfOp(const Instruction &inst) {
+ return isa<OperationInst>(inst) &&
+ cast<OperationInst>(inst).isa<AffineIfOp>();
+}
+
namespace mlir {
namespace matcher {
@@ -194,16 +200,22 @@ NestedPattern Op(FilterFunctionType filter) {
}
NestedPattern If(NestedPattern child) {
- return NestedPattern(Instruction::Kind::If, child, defaultFilterFunction);
+ return NestedPattern(Instruction::Kind::OperationInst, child, isAffineIfOp);
}
NestedPattern If(FilterFunctionType filter, NestedPattern child) {
- return NestedPattern(Instruction::Kind::If, child, filter);
+ return NestedPattern(Instruction::Kind::OperationInst, child,
+ [filter](const Instruction &inst) {
+ return isAffineIfOp(inst) && filter(inst);
+ });
}
NestedPattern If(ArrayRef<NestedPattern> nested) {
- return NestedPattern(Instruction::Kind::If, nested, defaultFilterFunction);
+ return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineIfOp);
}
NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
- return NestedPattern(Instruction::Kind::If, nested, filter);
+ return NestedPattern(Instruction::Kind::OperationInst, nested,
+ [filter](const Instruction &inst) {
+ return isAffineIfOp(inst) && filter(inst);
+ });
}
NestedPattern For(NestedPattern child) {
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 939a2ede618..0e77d4d9084 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -22,6 +22,7 @@
#include "mlir/Analysis/Utils.h"
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/IR/Builders.h"
@@ -43,7 +44,7 @@ void mlir::getLoopIVs(const Instruction &inst,
// Traverse up the hierarchy collecing all 'for' instruction while skipping
// over 'if' instructions.
while (currInst && ((currForInst = dyn_cast<ForInst>(currInst)) ||
- isa<IfInst>(currInst))) {
+ cast<OperationInst>(currInst)->isa<AffineIfOp>())) {
if (currForInst)
loops->push_back(currForInst);
currInst = currInst->getParentInst();
@@ -359,21 +360,12 @@ static Instruction *getInstAtPosition(ArrayRef<unsigned> positions,
if (auto *childForInst = dyn_cast<ForInst>(&inst))
return getInstAtPosition(positions, level + 1, childForInst->getBody());
- if (auto *ifInst = dyn_cast<IfInst>(&inst)) {
- auto *ret = getInstAtPosition(positions, level + 1, ifInst->getThen());
- if (ret != nullptr)
- return ret;
- if (auto *elseClause = ifInst->getElse())
- return getInstAtPosition(positions, level + 1, elseClause);
- }
- if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
- for (auto &blockList : opInst->getBlockLists()) {
- for (auto &b : blockList)
- if (auto *ret = getInstAtPosition(positions, level + 1, &b))
- return ret;
- }
- return nullptr;
+ for (auto &blockList : cast<OperationInst>(&inst)->getBlockLists()) {
+ for (auto &b : blockList)
+ if (auto *ret = getInstAtPosition(positions, level + 1, &b))
+ return ret;
}
+ return nullptr;
}
return nullptr;
}
diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp
index 383a4878c35..474eeb2a28e 100644
--- a/mlir/lib/Analysis/Verifier.cpp
+++ b/mlir/lib/Analysis/Verifier.cpp
@@ -73,7 +73,6 @@ public:
bool verifyBlock(const Block &block, bool isTopLevel);
bool verifyOperation(const OperationInst &op);
bool verifyForInst(const ForInst &forInst);
- bool verifyIfInst(const IfInst &ifInst);
bool verifyDominance(const Block &block);
bool verifyInstDominance(const Instruction &inst);
@@ -180,10 +179,6 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) {
if (verifyForInst(cast<ForInst>(inst)))
return true;
break;
- case Instruction::Kind::If:
- if (verifyIfInst(cast<IfInst>(inst)))
- return true;
- break;
}
}
@@ -250,18 +245,6 @@ bool FuncVerifier::verifyForInst(const ForInst &forInst) {
return verifyBlock(*forInst.getBody(), /*isTopLevel=*/false);
}
-bool FuncVerifier::verifyIfInst(const IfInst &ifInst) {
- // TODO: check that if conditions are properly formed.
- if (verifyBlock(*ifInst.getThen(), /*isTopLevel*/ false))
- return true;
-
- if (auto *elseClause = ifInst.getElse())
- if (verifyBlock(*elseClause, /*isTopLevel*/ false))
- return true;
-
- return false;
-}
-
bool FuncVerifier::verifyDominance(const Block &block) {
for (auto &inst : block) {
// Check that all operands on the instruction are ok.
@@ -283,14 +266,6 @@ bool FuncVerifier::verifyDominance(const Block &block) {
if (verifyDominance(*cast<ForInst>(inst).getBody()))
return true;
break;
- case Instruction::Kind::If:
- auto &ifInst = cast<IfInst>(inst);
- if (verifyDominance(*ifInst.getThen()))
- return true;
- if (auto *elseClause = ifInst.getElse())
- if (verifyDominance(*elseClause))
- return true;
- break;
}
}
return false;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 21bc3b824b1..cb4c1f0edce 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -145,7 +145,6 @@ private:
// Visit functions.
void visitInstruction(const Instruction *inst);
void visitForInst(const ForInst *forInst);
- void visitIfInst(const IfInst *ifInst);
void visitOperationInst(const OperationInst *opInst);
void visitType(Type type);
void visitAttribute(Attribute attr);
@@ -197,10 +196,6 @@ void ModuleState::visitAttribute(Attribute attr) {
}
}
-void ModuleState::visitIfInst(const IfInst *ifInst) {
- recordIntegerSetReference(ifInst->getIntegerSet());
-}
-
void ModuleState::visitForInst(const ForInst *forInst) {
AffineMap lbMap = forInst->getLowerBoundMap();
if (!hasCustomForm(lbMap))
@@ -225,8 +220,6 @@ void ModuleState::visitOperationInst(const OperationInst *op) {
void ModuleState::visitInstruction(const Instruction *inst) {
switch (inst->getKind()) {
- case Instruction::Kind::If:
- return visitIfInst(cast<IfInst>(inst));
case Instruction::Kind::For:
return visitForInst(cast<ForInst>(inst));
case Instruction::Kind::OperationInst:
@@ -1077,7 +1070,6 @@ public:
void print(const Instruction *inst);
void print(const OperationInst *inst);
void print(const ForInst *inst);
- void print(const IfInst *inst);
void print(const Block *block, bool printBlockArgs = true);
void printOperation(const OperationInst *op);
@@ -1125,6 +1117,9 @@ public:
unsigned index) override;
/// Print a block list.
+ void printBlockList(const BlockList &blocks) override {
+ printBlockList(blocks, /*printEntryBlockArgs=*/true);
+ }
void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) {
os << " {\n";
if (!blocks.empty()) {
@@ -1214,12 +1209,6 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) {
// Recursively number the stuff in the body.
numberValuesInBlock(*cast<ForInst>(&inst)->getBody());
break;
- case Instruction::Kind::If: {
- auto *ifInst = cast<IfInst>(&inst);
- numberValuesInBlock(*ifInst->getThen());
- if (auto *elseBlock = ifInst->getElse())
- numberValuesInBlock(*elseBlock);
- }
}
}
}
@@ -1360,8 +1349,7 @@ void FunctionPrinter::printFunctionSignature() {
}
void FunctionPrinter::print(const Block *block, bool printBlockArgs) {
- // Print the block label and argument list, unless this is the first block of
- // the function, or the first block of an IfInst/ForInst with no arguments.
+ // Print the block label and argument list if requested.
if (printBlockArgs) {
os.indent(currentIndent);
printBlockName(block);
@@ -1418,8 +1406,6 @@ void FunctionPrinter::print(const Instruction *inst) {
return print(cast<OperationInst>(inst));
case Instruction::Kind::For:
return print(cast<ForInst>(inst));
- case Instruction::Kind::If:
- return print(cast<IfInst>(inst));
}
}
@@ -1447,22 +1433,6 @@ void FunctionPrinter::print(const ForInst *inst) {
os.indent(currentIndent) << "}";
}
-void FunctionPrinter::print(const IfInst *inst) {
- os.indent(currentIndent) << "if ";
- IntegerSet set = inst->getIntegerSet();
- printIntegerSetReference(set);
- printDimAndSymbolList(inst->getInstOperands(), set.getNumDims());
- printTrailingLocation(inst->getLoc());
- os << " {\n";
- print(inst->getThen(), /*printBlockArgs=*/false);
- os.indent(currentIndent) << "}";
- if (inst->hasElse()) {
- os << " else {\n";
- print(inst->getElse(), /*printBlockArgs=*/false);
- os.indent(currentIndent) << "}";
- }
-}
-
void FunctionPrinter::printValueID(const Value *value,
bool printResultNo) const {
int resultNo = -1;
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 4471ff25e94..e174fdc1d00 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -327,10 +327,3 @@ ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
auto ubMap = AffineMap::getConstantMap(ub, context);
return createFor(location, {}, lbMap, {}, ubMap, step);
}
-
-IfInst *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
- IntegerSet set) {
- auto *inst = IfInst::create(location, operands, set);
- block->getInstructions().insert(insertPoint, inst);
- return inst;
-}
diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp
index 6d74ed14257..0ccab2305ec 100644
--- a/mlir/lib/IR/Instruction.cpp
+++ b/mlir/lib/IR/Instruction.cpp
@@ -73,9 +73,6 @@ void Instruction::destroy() {
case Kind::For:
delete cast<ForInst>(this);
break;
- case Kind::If:
- delete cast<IfInst>(this);
- break;
}
}
@@ -141,8 +138,6 @@ unsigned Instruction::getNumOperands() const {
return cast<OperationInst>(this)->getNumOperands();
case Kind::For:
return cast<ForInst>(this)->getNumOperands();
- case Kind::If:
- return cast<IfInst>(this)->getNumOperands();
}
}
@@ -152,8 +147,6 @@ MutableArrayRef<InstOperand> Instruction::getInstOperands() {
return cast<OperationInst>(this)->getInstOperands();
case Kind::For:
return cast<ForInst>(this)->getInstOperands();
- case Kind::If:
- return cast<IfInst>(this)->getInstOperands();
}
}
@@ -287,15 +280,6 @@ void Instruction::dropAllReferences() {
// Make sure to drop references held by instructions within the body.
cast<ForInst>(this)->getBody()->dropAllReferences();
break;
- case Kind::If: {
- // Make sure to drop references held by instructions within the 'then' and
- // 'else' blocks.
- auto *ifInst = cast<IfInst>(this);
- ifInst->getThen()->dropAllReferences();
- if (auto *elseBlock = ifInst->getElse())
- elseBlock->dropAllReferences();
- break;
- }
case Kind::OperationInst: {
auto *opInst = cast<OperationInst>(this);
if (isTerminator())
@@ -810,54 +794,6 @@ mlir::extractForInductionVars(ArrayRef<ForInst *> forInsts) {
return results;
}
//===----------------------------------------------------------------------===//
-// IfInst
-//===----------------------------------------------------------------------===//
-
-IfInst::IfInst(Location location, unsigned numOperands, IntegerSet set)
- : Instruction(Kind::If, location), thenClause(this), elseClause(nullptr),
- set(set) {
- operands.reserve(numOperands);
-
- // The then of an 'if' inst always has one block.
- thenClause.push_back(new Block());
-}
-
-IfInst::~IfInst() {
- if (elseClause)
- delete elseClause;
-
- // An IfInst's IntegerSet 'set' should not be deleted since it is
- // allocated through MLIRContext's bump pointer allocator.
-}
-
-IfInst *IfInst::create(Location location, ArrayRef<Value *> operands,
- IntegerSet set) {
- unsigned numOperands = operands.size();
- assert(numOperands == set.getNumOperands() &&
- "operand cound does not match the integer set operand count");
-
- IfInst *inst = new IfInst(location, numOperands, set);
-
- for (auto *op : operands)
- inst->operands.emplace_back(InstOperand(inst, op));
-
- return inst;
-}
-
-const AffineCondition IfInst::getCondition() const {
- return AffineCondition(*this, set);
-}
-
-MLIRContext *IfInst::getContext() const {
- // Check for degenerate case of if instruction with no operands.
- // This is unlikely, but legal.
- if (operands.empty())
- return getFunction()->getContext();
-
- return getOperand(0)->getType().getContext();
-}
-
-//===----------------------------------------------------------------------===//
// Instruction Cloning
//===----------------------------------------------------------------------===//
@@ -931,40 +867,23 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper,
for (auto *opValue : getOperands())
operands.push_back(mapper.lookupOrDefault(const_cast<Value *>(opValue)));
- if (auto *forInst = dyn_cast<ForInst>(this)) {
- auto lbMap = forInst->getLowerBoundMap();
- auto ubMap = forInst->getUpperBoundMap();
+ // Otherwise, this must be a ForInst.
+ auto *forInst = cast<ForInst>(this);
+ auto lbMap = forInst->getLowerBoundMap();
+ auto ubMap = forInst->getUpperBoundMap();
- auto *newFor = ForInst::create(
- getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
- lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()),
- ubMap, forInst->getStep());
+ auto *newFor = ForInst::create(
+ getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
+ lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()), ubMap,
+ forInst->getStep());
- // Remember the induction variable mapping.
- mapper.map(forInst->getInductionVar(), newFor->getInductionVar());
-
- // Recursively clone the body of the for loop.
- for (auto &subInst : *forInst->getBody())
- newFor->getBody()->push_back(subInst.clone(mapper, context));
-
- return newFor;
- }
-
- // Otherwise, we must have an If instruction.
- auto *ifInst = cast<IfInst>(this);
- auto *newIf = IfInst::create(getLoc(), operands, ifInst->getIntegerSet());
-
- auto *resultThen = newIf->getThen();
- for (auto &childInst : *ifInst->getThen())
- resultThen->push_back(childInst.clone(mapper, context));
-
- if (ifInst->hasElse()) {
- auto *resultElse = newIf->createElse();
- for (auto &childInst : *ifInst->getElse())
- resultElse->push_back(childInst.clone(mapper, context));
- }
+ // Remember the induction variable mapping.
+ mapper.map(forInst->getInductionVar(), newFor->getInductionVar());
- return newIf;
+ // Recursively clone the body of the for loop.
+ for (auto &subInst : *forInst->getBody())
+ newFor->getBody()->push_back(subInst.clone(mapper, context));
+ return newFor;
}
Instruction *Instruction::clone(MLIRContext *context) const {
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 099b218892f..2ab151f8913 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -281,7 +281,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) {
if (!block || &block->back() != op)
return op->emitOpError("must be the last instruction in the parent block");
- // Terminators may not exist in ForInst and IfInst.
+ // TODO(riverriddle) Terminators may not exist with an operation region.
if (block->getContainingInst())
return op->emitOpError("may only be at the top level of a function");
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index 6418b062dc1..7103eeb7389 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -66,8 +66,6 @@ MLIRContext *IROperandOwner::getContext() const {
return cast<OperationInst>(this)->getContext();
case Kind::ForInst:
return cast<ForInst>(this)->getContext();
- case Kind::IfInst:
- return cast<IfInst>(this)->getContext();
}
}
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index c477ad1bbc5..e5d6aa46565 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -996,8 +996,7 @@ Attribute Parser::parseAttribute(Type type) {
AffineMap map;
IntegerSet set;
if (parseAffineMapOrIntegerSetReference(map, set))
- return (emitError("expected affine map or integer set attribute value"),
- nullptr);
+ return nullptr;
if (map)
return builder.getAffineMapAttr(map);
assert(set);
@@ -2209,8 +2208,6 @@ public:
const char *affineStructName);
ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map,
bool isLower);
- ParseResult parseIfInst();
- ParseResult parseElseClause(Block *elseClause);
ParseResult parseInstructions(Block *block);
private:
@@ -2392,10 +2389,6 @@ ParseResult FunctionParser::parseBlockBody(Block *block) {
if (parseForInst())
return ParseFailure;
break;
- case Token::kw_if:
- if (parseIfInst())
- return ParseFailure;
- break;
}
}
@@ -2935,12 +2928,18 @@ public:
return false;
}
- /// Parse a keyword followed by a type.
- bool parseKeywordType(const char *keyword, Type &result) override {
- if (parser.getTokenSpelling() != keyword)
- return parser.emitError("expected '" + Twine(keyword) + "'");
- parser.consumeToken();
- return !(result = parser.parseType());
+ /// Parse an optional keyword.
+ bool parseOptionalKeyword(const char *keyword) override {
+ // Check that the current token is a bare identifier or keyword.
+ if (parser.getToken().isNot(Token::bare_identifier) &&
+ !parser.getToken().isKeyword())
+ return true;
+
+ if (parser.getTokenSpelling() == keyword) {
+ parser.consumeToken();
+ return false;
+ }
+ return true;
}
/// Parse an arbitrary attribute of a given type and return it in result. This
@@ -3078,6 +3077,15 @@ public:
return result == nullptr;
}
+ /// Parses a list of blocks.
+ bool parseBlockList() override {
+ SmallVector<Block *, 2> results;
+ if (parser.parseOperationBlockList(results))
+ return true;
+ parsedBlockLists.emplace_back(results);
+ return false;
+ }
+
//===--------------------------------------------------------------------===//
// Methods for interacting with the parser
//===--------------------------------------------------------------------===//
@@ -3099,6 +3107,11 @@ public:
/// Emit a diagnostic at the specified location and return true.
bool emitError(llvm::SMLoc loc, const Twine &message) override {
+ // If we emit an error, then cleanup any parsed block lists.
+ for (auto &blockList : parsedBlockLists)
+ parser.cleanupInvalidBlocks(blockList);
+ parsedBlockLists.clear();
+
parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message);
emittedError = true;
return true;
@@ -3106,7 +3119,13 @@ public:
bool didEmitError() const { return emittedError; }
+ /// Returns the block lists that were parsed.
+ MutableArrayRef<SmallVector<Block *, 2>> getParsedBlockLists() {
+ return parsedBlockLists;
+ }
+
private:
+ std::vector<SmallVector<Block *, 2>> parsedBlockLists;
SMLoc nameLoc;
StringRef opName;
FunctionParser &parser;
@@ -3145,8 +3164,25 @@ OperationInst *FunctionParser::parseCustomOperation() {
if (opAsmParser.didEmitError())
return nullptr;
+ // Check that enough block lists were reserved for those that were parsed.
+ auto parsedBlockLists = opAsmParser.getParsedBlockLists();
+ if (parsedBlockLists.size() > opState.numBlockLists) {
+ opAsmParser.emitError(
+ opLoc,
+ "parsed more block lists than those reserved in the operation state");
+ return nullptr;
+ }
+
// Otherwise, we succeeded. Use the state it parsed as our op information.
- return builder.createOperation(opState);
+ auto *opInst = builder.createOperation(opState);
+
+ // Resolve any parsed block lists.
+ for (unsigned i = 0, e = parsedBlockLists.size(); i != e; ++i) {
+ auto &opBlockList = opInst->getBlockList(i).getBlocks();
+ opBlockList.insert(opBlockList.end(), parsedBlockLists[i].begin(),
+ parsedBlockLists[i].end());
+ }
+ return opInst;
}
/// For instruction.
@@ -3438,69 +3474,6 @@ IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims,
return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs);
}
-/// If instruction.
-///
-/// ml-if-head ::= `if` ml-if-cond trailing-location? `{` inst* `}`
-/// | ml-if-head `else` `if` ml-if-cond trailing-location?
-/// `{` inst* `}`
-/// ml-if-inst ::= ml-if-head
-/// | ml-if-head `else` `{` inst* `}`
-///
-ParseResult FunctionParser::parseIfInst() {
- auto loc = getToken().getLoc();
- consumeToken(Token::kw_if);
-
- IntegerSet set = parseIntegerSetReference();
- if (!set)
- return ParseFailure;
-
- SmallVector<Value *, 4> operands;
- if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(),
- "integer set"))
- return ParseFailure;
-
- IfInst *ifInst =
- builder.createIf(getEncodedSourceLocation(loc), operands, set);
-
- // Try to parse the optional trailing location.
- if (parseOptionalTrailingLocation(ifInst))
- return ParseFailure;
-
- Block *thenClause = ifInst->getThen();
-
- // When parsing of an if instruction body fails, the IR contains
- // the if instruction with the portion of the body that has been
- // successfully parsed.
- if (parseToken(Token::l_brace, "expected '{' before instruction list") ||
- parseBlock(thenClause) ||
- parseToken(Token::r_brace, "expected '}' after instruction list"))
- return ParseFailure;
-
- if (consumeIf(Token::kw_else)) {
- auto *elseClause = ifInst->createElse();
- if (parseElseClause(elseClause))
- return ParseFailure;
- }
-
- // Reset insertion point to the current block.
- builder.setInsertionPointToEnd(ifInst->getBlock());
-
- return ParseSuccess;
-}
-
-ParseResult FunctionParser::parseElseClause(Block *elseClause) {
- if (getToken().is(Token::kw_if)) {
- builder.setInsertionPointToEnd(elseClause);
- return parseIfInst();
- }
-
- if (parseToken(Token::l_brace, "expected '{' before instruction list") ||
- parseBlock(elseClause) ||
- parseToken(Token::r_brace, "expected '}' after instruction list"))
- return ParseFailure;
- return ParseSuccess;
-}
-
//===----------------------------------------------------------------------===//
// Top-level entity parsing.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def
index 40e98b25cb3..ec00f98b3f5 100644
--- a/mlir/lib/Parser/TokenKinds.def
+++ b/mlir/lib/Parser/TokenKinds.def
@@ -91,7 +91,6 @@ TOK_KEYWORD(attributes)
TOK_KEYWORD(bf16)
TOK_KEYWORD(ceildiv)
TOK_KEYWORD(dense)
-TOK_KEYWORD(else)
TOK_KEYWORD(splat)
TOK_KEYWORD(f16)
TOK_KEYWORD(f32)
@@ -100,7 +99,6 @@ TOK_KEYWORD(false)
TOK_KEYWORD(floordiv)
TOK_KEYWORD(for)
TOK_KEYWORD(func)
-TOK_KEYWORD(if)
TOK_KEYWORD(index)
TOK_KEYWORD(loc)
TOK_KEYWORD(max)
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index c2e1636626d..afd18a49b79 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -188,16 +188,6 @@ void CSE::simplifyBlock(Block *bb) {
simplifyBlock(cast<ForInst>(i).getBody());
break;
}
- case Instruction::Kind::If: {
- auto &ifInst = cast<IfInst>(i);
- if (auto *elseBlock = ifInst.getElse()) {
- ScopedMapTy::ScopeTy scope(knownValues);
- simplifyBlock(elseBlock);
- }
- ScopedMapTy::ScopeTy scope(knownValues);
- simplifyBlock(ifInst.getThen());
- break;
- }
}
}
}
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index cee0a08a63c..eebbbe9daa7 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -19,6 +19,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
@@ -99,16 +100,16 @@ public:
SmallVector<ForInst *, 4> forInsts;
SmallVector<OperationInst *, 4> loadOpInsts;
SmallVector<OperationInst *, 4> storeOpInsts;
- bool hasIfInst = false;
+ bool hasNonForRegion = false;
void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
- void visitIfInst(IfInst *ifInst) { hasIfInst = true; }
-
void visitOperationInst(OperationInst *opInst) {
- if (opInst->isa<LoadOp>())
+ if (opInst->getNumBlockLists() != 0)
+ hasNonForRegion = true;
+ else if (opInst->isa<LoadOp>())
loadOpInsts.push_back(opInst);
- if (opInst->isa<StoreOp>())
+ else if (opInst->isa<StoreOp>())
storeOpInsts.push_back(opInst);
}
};
@@ -410,8 +411,8 @@ bool MemRefDependenceGraph::init(Function *f) {
// all loads and store accesses it contains.
LoopNestStateCollector collector;
collector.walkForInst(forInst);
- // Return false if IfInsts are found (not currently supported).
- if (collector.hasIfInst)
+ // Return false if a non 'for' region was found (not currently supported).
+ if (collector.hasNonForRegion)
return false;
Node node(id++, &inst);
for (auto *opInst : collector.loadOpInsts) {
@@ -434,19 +435,18 @@ bool MemRefDependenceGraph::init(Function *f) {
auto *memref = opInst->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
- }
- if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
+ } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
// Create graph node for top-level store op.
Node node(id++, &inst);
node.stores.push_back(opInst);
auto *memref = opInst->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
+ } else if (opInst->getNumBlockLists() != 0) {
+ // Return false if another region is found (not currently supported).
+ return false;
}
}
- // Return false if IfInsts are found (not currently supported).
- if (isa<IfInst>(&inst))
- return false;
}
// Walk memref access lists and add graph edges between dependent nodes.
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index 39ef758833b..6d63e4afd2d 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -119,15 +119,6 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
return true;
}
- bool walkIfInstPostOrder(IfInst *ifInst) {
- bool hasInnerLoops =
- walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end());
- if (ifInst->hasElse())
- hasInnerLoops |=
- walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end());
- return hasInnerLoops;
- }
-
bool walkOpInstPostOrder(OperationInst *opInst) {
for (auto &blockList : opInst->getBlockLists())
for (auto &block : blockList)
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp
index ab37ff63bad..f770684f519 100644
--- a/mlir/lib/Transforms/LowerAffine.cpp
+++ b/mlir/lib/Transforms/LowerAffine.cpp
@@ -20,6 +20,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -246,7 +247,7 @@ public:
PassResult runOnFunction(Function *function) override;
bool lowerForInst(ForInst *forInst);
- bool lowerIfInst(IfInst *ifInst);
+ bool lowerAffineIf(AffineIfOp *ifOp);
bool lowerAffineApply(AffineApplyOp *op);
static char passID;
@@ -409,7 +410,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
// enabling easy nesting of "if" instructions and if-then-else-if chains.
//
// +--------------------------------+
-// | <code before the IfInst> |
+// | <code before the AffineIfOp> |
// | %zero = constant 0 : index |
// | %v = affine_apply #expr1(%ops) |
// | %c = cmpi "sge" %v, %zero |
@@ -453,10 +454,11 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
// v v
// +--------------------------------+
// | continue: |
-// | <code after the IfInst> |
+// | <code after the AffineIfOp> |
// +--------------------------------+
//
-bool LowerAffinePass::lowerIfInst(IfInst *ifInst) {
+bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) {
+ auto *ifInst = ifOp->getInstruction();
auto loc = ifInst->getLoc();
// Start by splitting the block containing the 'if' into two parts. The part
@@ -466,22 +468,38 @@ bool LowerAffinePass::lowerIfInst(IfInst *ifInst) {
auto *continueBlock = condBlock->splitBlock(ifInst);
// Create a block for the 'then' code, inserting it between the cond and
- // continue blocks. Move the instructions over from the IfInst and add a
+ // continue blocks. Move the instructions over from the AffineIfOp and add a
// branch to the continuation point.
Block *thenBlock = new Block();
thenBlock->insertBefore(continueBlock);
- auto *oldThen = ifInst->getThen();
- thenBlock->getInstructions().splice(thenBlock->begin(),
- oldThen->getInstructions(),
- oldThen->begin(), oldThen->end());
+ // If the 'then' block is not empty, then splice the instructions.
+ auto &oldThenBlocks = ifOp->getThenBlocks();
+ if (!oldThenBlocks.empty()) {
+ // We currently only handle one 'then' block.
+ if (std::next(oldThenBlocks.begin()) != oldThenBlocks.end())
+ return true;
+
+ Block *oldThen = &oldThenBlocks.front();
+
+ thenBlock->getInstructions().splice(thenBlock->begin(),
+ oldThen->getInstructions(),
+ oldThen->begin(), oldThen->end());
+ }
+
FuncBuilder builder(thenBlock);
builder.create<BranchOp>(loc, continueBlock);
// Handle the 'else' block the same way, but we skip it if we have no else
// code.
Block *elseBlock = continueBlock;
- if (auto *oldElse = ifInst->getElse()) {
+ auto &oldElseBlocks = ifOp->getElseBlocks();
+ if (!oldElseBlocks.empty()) {
+ // We currently only handle one 'else' block.
+ if (std::next(oldElseBlocks.begin()) != oldElseBlocks.end())
+ return true;
+
+ auto *oldElse = &oldElseBlocks.front();
elseBlock = new Block();
elseBlock->insertBefore(continueBlock);
@@ -493,7 +511,7 @@ bool LowerAffinePass::lowerIfInst(IfInst *ifInst) {
}
// Ok, now we just have to handle the condition logic.
- auto integerSet = ifInst->getCondition().getIntegerSet();
+ auto integerSet = ifOp->getIntegerSet();
// Implement short-circuit logic. For each affine expression in the 'if'
// condition, convert it into an affine map and call `affine_apply` to obtain
@@ -593,29 +611,30 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) {
PassResult LowerAffinePass::runOnFunction(Function *function) {
SmallVector<Instruction *, 8> instsToRewrite;
- // Collect all the If and For instructions as well as AffineApplyOps. We do
- // this as a prepass to avoid invalidating the walker with our rewrite.
+ // Collect all the For instructions as well as AffineIfOps and AffineApplyOps.
+ // We do this as a prepass to avoid invalidating the walker with our rewrite.
function->walkInsts([&](Instruction *inst) {
- if (isa<IfInst>(inst) || isa<ForInst>(inst))
+ if (isa<ForInst>(inst))
instsToRewrite.push_back(inst);
auto op = dyn_cast<OperationInst>(inst);
- if (op && op->isa<AffineApplyOp>())
+ if (op && (op->isa<AffineApplyOp>() || op->isa<AffineIfOp>()))
instsToRewrite.push_back(inst);
});
// Rewrite all of the ifs and fors. We walked the instructions in preorder,
// so we know that we will rewrite them in the same order.
for (auto *inst : instsToRewrite)
- if (auto *ifInst = dyn_cast<IfInst>(inst)) {
- if (lowerIfInst(ifInst))
- return failure();
- } else if (auto *forInst = dyn_cast<ForInst>(inst)) {
+ if (auto *forInst = dyn_cast<ForInst>(inst)) {
if (lowerForInst(forInst))
return failure();
} else {
auto op = cast<OperationInst>(inst);
- if (lowerAffineApply(op->cast<AffineApplyOp>()))
+ if (auto ifOp = op->dyn_cast<AffineIfOp>()) {
+ if (lowerAffineIf(ifOp))
+ return failure();
+ } else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
return failure();
+ }
}
return success();
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index 09d961f85cd..2744b1d624c 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -20,6 +20,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/LoopAnalysis.h"
@@ -559,9 +560,6 @@ static bool instantiateMaterialization(Instruction *inst,
if (isa<ForInst>(inst))
return inst->emitError("NYI path ForInst");
- if (isa<IfInst>(inst))
- return inst->emitError("NYI path IfInst");
-
// Create a builder here for unroll-and-jam effects.
FuncBuilder b(inst);
auto *opInst = cast<OperationInst>(inst);
@@ -570,6 +568,9 @@ static bool instantiateMaterialization(Instruction *inst,
if (opInst->isa<AffineApplyOp>()) {
return false;
}
+ if (opInst->getNumBlockLists() != 0)
+ return inst->emitError("NYI path Op with region");
+
if (auto write = opInst->dyn_cast<VectorTransferWriteOp>()) {
auto *clone = instantiate(&b, write, state->hwVectorType,
state->hwVectorInstance, state->substitutionsMap);
diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp
index bd39e47786a..ba59123c700 100644
--- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp
@@ -28,7 +28,6 @@
#define DEBUG_TYPE "simplify-affine-structure"
using namespace mlir;
-using llvm::report_fatal_error;
namespace {
@@ -42,9 +41,6 @@ struct SimplifyAffineStructures : public FunctionPass {
PassResult runOnFunction(Function *f) override;
- void visitIfInst(IfInst *ifInst);
- void visitOperationInst(OperationInst *opInst);
-
static char passID;
};
@@ -66,28 +62,19 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) {
return set;
}
-void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) {
- auto set = ifInst->getCondition().getIntegerSet();
- ifInst->setIntegerSet(simplifyIntegerSet(set));
-}
-
-void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) {
- for (auto attr : opInst->getAttrs()) {
- if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
- MutableAffineMap mMap(mapAttr.getValue());
- mMap.simplify();
- auto map = mMap.getAffineMap();
- opInst->setAttr(attr.first, AffineMapAttr::get(map));
- }
- }
-}
-
PassResult SimplifyAffineStructures::runOnFunction(Function *f) {
- f->walkInsts([&](Instruction *inst) {
- if (auto *opInst = dyn_cast<OperationInst>(inst))
- visitOperationInst(opInst);
- if (auto *ifInst = dyn_cast<IfInst>(inst))
- visitIfInst(ifInst);
+ f->walkOps([&](OperationInst *opInst) {
+ for (auto attr : opInst->getAttrs()) {
+ if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
+ MutableAffineMap mMap(mapAttr.getValue());
+ mMap.simplify();
+ auto map = mMap.getAffineMap();
+ opInst->setAttr(attr.first, AffineMapAttr::get(map));
+ } else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>()) {
+ auto simplified = simplifyIntegerSet(setAttr.getValue());
+ opInst->setAttr(attr.first, IntegerSetAttr::get(simplified));
+ }
+ }
});
return success();
OpenPOWER on IntegriCloud