diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/AffineOps/AffineOps.cpp | 151 | ||||
| -rw-r--r-- | mlir/lib/AffineOps/DialectRegistration.cpp | 22 | ||||
| -rw-r--r-- | mlir/lib/Analysis/LoopAnalysis.cpp | 11 | ||||
| -rw-r--r-- | mlir/lib/Analysis/NestedMatcher.cpp | 20 | ||||
| -rw-r--r-- | mlir/lib/Analysis/Utils.cpp | 22 | ||||
| -rw-r--r-- | mlir/lib/Analysis/Verifier.cpp | 25 | ||||
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 38 | ||||
| -rw-r--r-- | mlir/lib/IR/Builders.cpp | 7 | ||||
| -rw-r--r-- | mlir/lib/IR/Instruction.cpp | 109 | ||||
| -rw-r--r-- | mlir/lib/IR/Operation.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/IR/Value.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Parser/Parser.cpp | 129 | ||||
| -rw-r--r-- | mlir/lib/Parser/TokenKinds.def | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/CSE.cpp | 10 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopFusion.cpp | 24 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LoopUnroll.cpp | 9 | ||||
| -rw-r--r-- | mlir/lib/Transforms/LowerAffine.cpp | 59 | ||||
| -rw-r--r-- | mlir/lib/Transforms/MaterializeVectors.cpp | 7 | ||||
| -rw-r--r-- | mlir/lib/Transforms/SimplifyAffineStructures.cpp | 37 |
19 files changed, 344 insertions, 342 deletions
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp new file mode 100644 index 00000000000..5b29467fc44 --- /dev/null +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -0,0 +1,151 @@ +//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/AffineOps/AffineOps.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpImplementation.h" +using namespace mlir; + +//===----------------------------------------------------------------------===// +// AffineOpsDialect +//===----------------------------------------------------------------------===// + +AffineOpsDialect::AffineOpsDialect(MLIRContext *context) + : Dialect(/*namePrefix=*/"", context) { + addOperations<AffineIfOp>(); +} + +//===----------------------------------------------------------------------===// +// AffineIfOp +//===----------------------------------------------------------------------===// + +void AffineIfOp::build(Builder *builder, OperationState *result, + IntegerSet condition, + ArrayRef<Value *> conditionOperands) { + result->addAttribute(getConditionAttrName(), IntegerSetAttr::get(condition)); + result->addOperands(conditionOperands); + + // Reserve 2 block lists, one for the 'then' and one for the 'else' regions. + result->reserveBlockLists(2); +} + +bool AffineIfOp::verify() const { + // Verify that we have a condition attribute. + auto conditionAttr = getAttrOfType<IntegerSetAttr>(getConditionAttrName()); + if (!conditionAttr) + return emitOpError("requires an integer set attribute named 'condition'"); + + // Verify that the operands are valid dimension/symbols. + IntegerSet condition = conditionAttr.getValue(); + for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { + const Value *operand = getOperand(i); + if (i < condition.getNumDims() && !operand->isValidDim()) + return emitOpError("operand cannot be used as a dimension id"); + if (i >= condition.getNumDims() && !operand->isValidSymbol()) + return emitOpError("operand cannot be used as a symbol"); + } + + // Verify that the entry of each child blocklist does not have arguments. + for (const auto &blockList : getInstruction()->getBlockLists()) { + if (blockList.empty()) + continue; + + // TODO(riverriddle) We currently do not allow multiple blocks in child + // block lists. + if (std::next(blockList.begin()) != blockList.end()) + return emitOpError( + "expects only one block per 'if' or 'else' block list"); + if (blockList.front().getTerminator()) + return emitOpError("expects region block to not have a terminator"); + + for (const auto &b : blockList) + if (b.getNumArguments() != 0) + return emitOpError( + "requires that child entry blocks have no arguments"); + } + return false; +} + +bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) { + // Parse the condition attribute set. + IntegerSetAttr conditionAttr; + unsigned numDims; + if (parser->parseAttribute(conditionAttr, getConditionAttrName().data(), + result->attributes) || + parseDimAndSymbolList(parser, result->operands, numDims)) + return true; + + // Verify the condition operands. + auto set = conditionAttr.getValue(); + if (set.getNumDims() != numDims) + return parser->emitError( + parser->getNameLoc(), + "dim operand count and integer set dim count must match"); + if (numDims + set.getNumSymbols() != result->operands.size()) + return parser->emitError( + parser->getNameLoc(), + "symbol operand count and integer set symbol count must match"); + + // Parse the 'then' block list. + if (parser->parseBlockList()) + return true; + + // If we find an 'else' keyword then parse the else block list. + if (!parser->parseOptionalKeyword("else")) { + if (parser->parseBlockList()) + return true; + } + + // Reserve 2 block lists, one for the 'then' and one for the 'else' regions. + result->reserveBlockLists(2); + return false; +} + +void AffineIfOp::print(OpAsmPrinter *p) const { + auto conditionAttr = getAttrOfType<IntegerSetAttr>(getConditionAttrName()); + *p << "if " << conditionAttr; + printDimAndSymbolList(operand_begin(), operand_end(), + conditionAttr.getValue().getNumDims(), p); + p->printBlockList(getInstruction()->getBlockList(0)); + + // Print the 'else' block list if it has any blocks. + const auto &elseBlockList = getInstruction()->getBlockList(1); + if (!elseBlockList.empty()) { + *p << " else"; + p->printBlockList(elseBlockList); + } +} + +IntegerSet AffineIfOp::getIntegerSet() const { + return getAttrOfType<IntegerSetAttr>(getConditionAttrName()).getValue(); +} +void AffineIfOp::setIntegerSet(IntegerSet newSet) { + setAttr( + Identifier::get(getConditionAttrName(), getInstruction()->getContext()), + IntegerSetAttr::get(newSet)); +} + +/// Returns the list of 'then' blocks. +BlockList &AffineIfOp::getThenBlocks() { + return getInstruction()->getBlockList(0); +} + +/// Returns the list of 'else' blocks. +BlockList &AffineIfOp::getElseBlocks() { + return getInstruction()->getBlockList(1); +} diff --git a/mlir/lib/AffineOps/DialectRegistration.cpp b/mlir/lib/AffineOps/DialectRegistration.cpp new file mode 100644 index 00000000000..0afb32c1bd6 --- /dev/null +++ b/mlir/lib/AffineOps/DialectRegistration.cpp @@ -0,0 +1,22 @@ +//===- DialectRegistration.cpp - Register Affine Op dialect ---------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/AffineOps/AffineOps.h" +using namespace mlir; + +// Static initialization for Affine op dialect registration. +static DialectRegistration<AffineOpsDialect> StandardOps; diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 219f356807a..07c903a6613 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -21,6 +21,7 @@ #include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/NestedMatcher.h" @@ -246,6 +247,16 @@ static bool isVectorizableLoopWithCond(const ForInst &loop, return false; } + // No vectorization across unknown regions. + auto regions = matcher::Op([](const Instruction &inst) -> bool { + auto &opInst = cast<OperationInst>(inst); + return opInst.getNumBlockLists() != 0 && !opInst.isa<AffineIfOp>(); + }); + auto regionsMatched = regions.match(forInst); + if (!regionsMatched.empty()) { + return false; + } + auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite); auto vectorTransfersMatched = vectorTransfers.match(forInst); if (!vectorTransfersMatched.empty()) { diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 4f32e9b22f4..491a9bef1b9 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -16,6 +16,7 @@ // ============================================================================= #include "mlir/Analysis/NestedMatcher.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/ArrayRef.h" @@ -186,6 +187,11 @@ FilterFunctionType NestedPattern::getFilterFunction() { return storage->filter; } +static bool isAffineIfOp(const Instruction &inst) { + return isa<OperationInst>(inst) && + cast<OperationInst>(inst).isa<AffineIfOp>(); +} + namespace mlir { namespace matcher { @@ -194,16 +200,22 @@ NestedPattern Op(FilterFunctionType filter) { } NestedPattern If(NestedPattern child) { - return NestedPattern(Instruction::Kind::If, child, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, child, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(Instruction::Kind::If, child, filter); + return NestedPattern(Instruction::Kind::OperationInst, child, + [filter](const Instruction &inst) { + return isAffineIfOp(inst) && filter(inst); + }); } NestedPattern If(ArrayRef<NestedPattern> nested) { - return NestedPattern(Instruction::Kind::If, nested, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) { - return NestedPattern(Instruction::Kind::If, nested, filter); + return NestedPattern(Instruction::Kind::OperationInst, nested, + [filter](const Instruction &inst) { + return isAffineIfOp(inst) && filter(inst); + }); } NestedPattern For(NestedPattern child) { diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 939a2ede618..0e77d4d9084 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -22,6 +22,7 @@ #include "mlir/Analysis/Utils.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/Builders.h" @@ -43,7 +44,7 @@ void mlir::getLoopIVs(const Instruction &inst, // Traverse up the hierarchy collecing all 'for' instruction while skipping // over 'if' instructions. while (currInst && ((currForInst = dyn_cast<ForInst>(currInst)) || - isa<IfInst>(currInst))) { + cast<OperationInst>(currInst)->isa<AffineIfOp>())) { if (currForInst) loops->push_back(currForInst); currInst = currInst->getParentInst(); @@ -359,21 +360,12 @@ static Instruction *getInstAtPosition(ArrayRef<unsigned> positions, if (auto *childForInst = dyn_cast<ForInst>(&inst)) return getInstAtPosition(positions, level + 1, childForInst->getBody()); - if (auto *ifInst = dyn_cast<IfInst>(&inst)) { - auto *ret = getInstAtPosition(positions, level + 1, ifInst->getThen()); - if (ret != nullptr) - return ret; - if (auto *elseClause = ifInst->getElse()) - return getInstAtPosition(positions, level + 1, elseClause); - } - if (auto *opInst = dyn_cast<OperationInst>(&inst)) { - for (auto &blockList : opInst->getBlockLists()) { - for (auto &b : blockList) - if (auto *ret = getInstAtPosition(positions, level + 1, &b)) - return ret; - } - return nullptr; + for (auto &blockList : cast<OperationInst>(&inst)->getBlockLists()) { + for (auto &b : blockList) + if (auto *ret = getInstAtPosition(positions, level + 1, &b)) + return ret; } + return nullptr; } return nullptr; } diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 383a4878c35..474eeb2a28e 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -73,7 +73,6 @@ public: bool verifyBlock(const Block &block, bool isTopLevel); bool verifyOperation(const OperationInst &op); bool verifyForInst(const ForInst &forInst); - bool verifyIfInst(const IfInst &ifInst); bool verifyDominance(const Block &block); bool verifyInstDominance(const Instruction &inst); @@ -180,10 +179,6 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) { if (verifyForInst(cast<ForInst>(inst))) return true; break; - case Instruction::Kind::If: - if (verifyIfInst(cast<IfInst>(inst))) - return true; - break; } } @@ -250,18 +245,6 @@ bool FuncVerifier::verifyForInst(const ForInst &forInst) { return verifyBlock(*forInst.getBody(), /*isTopLevel=*/false); } -bool FuncVerifier::verifyIfInst(const IfInst &ifInst) { - // TODO: check that if conditions are properly formed. - if (verifyBlock(*ifInst.getThen(), /*isTopLevel*/ false)) - return true; - - if (auto *elseClause = ifInst.getElse()) - if (verifyBlock(*elseClause, /*isTopLevel*/ false)) - return true; - - return false; -} - bool FuncVerifier::verifyDominance(const Block &block) { for (auto &inst : block) { // Check that all operands on the instruction are ok. @@ -283,14 +266,6 @@ bool FuncVerifier::verifyDominance(const Block &block) { if (verifyDominance(*cast<ForInst>(inst).getBody())) return true; break; - case Instruction::Kind::If: - auto &ifInst = cast<IfInst>(inst); - if (verifyDominance(*ifInst.getThen())) - return true; - if (auto *elseClause = ifInst.getElse()) - if (verifyDominance(*elseClause)) - return true; - break; } } return false; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 21bc3b824b1..cb4c1f0edce 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -145,7 +145,6 @@ private: // Visit functions. void visitInstruction(const Instruction *inst); void visitForInst(const ForInst *forInst); - void visitIfInst(const IfInst *ifInst); void visitOperationInst(const OperationInst *opInst); void visitType(Type type); void visitAttribute(Attribute attr); @@ -197,10 +196,6 @@ void ModuleState::visitAttribute(Attribute attr) { } } -void ModuleState::visitIfInst(const IfInst *ifInst) { - recordIntegerSetReference(ifInst->getIntegerSet()); -} - void ModuleState::visitForInst(const ForInst *forInst) { AffineMap lbMap = forInst->getLowerBoundMap(); if (!hasCustomForm(lbMap)) @@ -225,8 +220,6 @@ void ModuleState::visitOperationInst(const OperationInst *op) { void ModuleState::visitInstruction(const Instruction *inst) { switch (inst->getKind()) { - case Instruction::Kind::If: - return visitIfInst(cast<IfInst>(inst)); case Instruction::Kind::For: return visitForInst(cast<ForInst>(inst)); case Instruction::Kind::OperationInst: @@ -1077,7 +1070,6 @@ public: void print(const Instruction *inst); void print(const OperationInst *inst); void print(const ForInst *inst); - void print(const IfInst *inst); void print(const Block *block, bool printBlockArgs = true); void printOperation(const OperationInst *op); @@ -1125,6 +1117,9 @@ public: unsigned index) override; /// Print a block list. + void printBlockList(const BlockList &blocks) override { + printBlockList(blocks, /*printEntryBlockArgs=*/true); + } void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) { os << " {\n"; if (!blocks.empty()) { @@ -1214,12 +1209,6 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { // Recursively number the stuff in the body. numberValuesInBlock(*cast<ForInst>(&inst)->getBody()); break; - case Instruction::Kind::If: { - auto *ifInst = cast<IfInst>(&inst); - numberValuesInBlock(*ifInst->getThen()); - if (auto *elseBlock = ifInst->getElse()) - numberValuesInBlock(*elseBlock); - } } } } @@ -1360,8 +1349,7 @@ void FunctionPrinter::printFunctionSignature() { } void FunctionPrinter::print(const Block *block, bool printBlockArgs) { - // Print the block label and argument list, unless this is the first block of - // the function, or the first block of an IfInst/ForInst with no arguments. + // Print the block label and argument list if requested. if (printBlockArgs) { os.indent(currentIndent); printBlockName(block); @@ -1418,8 +1406,6 @@ void FunctionPrinter::print(const Instruction *inst) { return print(cast<OperationInst>(inst)); case Instruction::Kind::For: return print(cast<ForInst>(inst)); - case Instruction::Kind::If: - return print(cast<IfInst>(inst)); } } @@ -1447,22 +1433,6 @@ void FunctionPrinter::print(const ForInst *inst) { os.indent(currentIndent) << "}"; } -void FunctionPrinter::print(const IfInst *inst) { - os.indent(currentIndent) << "if "; - IntegerSet set = inst->getIntegerSet(); - printIntegerSetReference(set); - printDimAndSymbolList(inst->getInstOperands(), set.getNumDims()); - printTrailingLocation(inst->getLoc()); - os << " {\n"; - print(inst->getThen(), /*printBlockArgs=*/false); - os.indent(currentIndent) << "}"; - if (inst->hasElse()) { - os << " else {\n"; - print(inst->getElse(), /*printBlockArgs=*/false); - os.indent(currentIndent) << "}"; - } -} - void FunctionPrinter::printValueID(const Value *value, bool printResultNo) const { int resultNo = -1; diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 4471ff25e94..e174fdc1d00 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -327,10 +327,3 @@ ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub, auto ubMap = AffineMap::getConstantMap(ub, context); return createFor(location, {}, lbMap, {}, ubMap, step); } - -IfInst *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands, - IntegerSet set) { - auto *inst = IfInst::create(location, operands, set); - block->getInstructions().insert(insertPoint, inst); - return inst; -} diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index 6d74ed14257..0ccab2305ec 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -73,9 +73,6 @@ void Instruction::destroy() { case Kind::For: delete cast<ForInst>(this); break; - case Kind::If: - delete cast<IfInst>(this); - break; } } @@ -141,8 +138,6 @@ unsigned Instruction::getNumOperands() const { return cast<OperationInst>(this)->getNumOperands(); case Kind::For: return cast<ForInst>(this)->getNumOperands(); - case Kind::If: - return cast<IfInst>(this)->getNumOperands(); } } @@ -152,8 +147,6 @@ MutableArrayRef<InstOperand> Instruction::getInstOperands() { return cast<OperationInst>(this)->getInstOperands(); case Kind::For: return cast<ForInst>(this)->getInstOperands(); - case Kind::If: - return cast<IfInst>(this)->getInstOperands(); } } @@ -287,15 +280,6 @@ void Instruction::dropAllReferences() { // Make sure to drop references held by instructions within the body. cast<ForInst>(this)->getBody()->dropAllReferences(); break; - case Kind::If: { - // Make sure to drop references held by instructions within the 'then' and - // 'else' blocks. - auto *ifInst = cast<IfInst>(this); - ifInst->getThen()->dropAllReferences(); - if (auto *elseBlock = ifInst->getElse()) - elseBlock->dropAllReferences(); - break; - } case Kind::OperationInst: { auto *opInst = cast<OperationInst>(this); if (isTerminator()) @@ -810,54 +794,6 @@ mlir::extractForInductionVars(ArrayRef<ForInst *> forInsts) { return results; } //===----------------------------------------------------------------------===// -// IfInst -//===----------------------------------------------------------------------===// - -IfInst::IfInst(Location location, unsigned numOperands, IntegerSet set) - : Instruction(Kind::If, location), thenClause(this), elseClause(nullptr), - set(set) { - operands.reserve(numOperands); - - // The then of an 'if' inst always has one block. - thenClause.push_back(new Block()); -} - -IfInst::~IfInst() { - if (elseClause) - delete elseClause; - - // An IfInst's IntegerSet 'set' should not be deleted since it is - // allocated through MLIRContext's bump pointer allocator. -} - -IfInst *IfInst::create(Location location, ArrayRef<Value *> operands, - IntegerSet set) { - unsigned numOperands = operands.size(); - assert(numOperands == set.getNumOperands() && - "operand cound does not match the integer set operand count"); - - IfInst *inst = new IfInst(location, numOperands, set); - - for (auto *op : operands) - inst->operands.emplace_back(InstOperand(inst, op)); - - return inst; -} - -const AffineCondition IfInst::getCondition() const { - return AffineCondition(*this, set); -} - -MLIRContext *IfInst::getContext() const { - // Check for degenerate case of if instruction with no operands. - // This is unlikely, but legal. - if (operands.empty()) - return getFunction()->getContext(); - - return getOperand(0)->getType().getContext(); -} - -//===----------------------------------------------------------------------===// // Instruction Cloning //===----------------------------------------------------------------------===// @@ -931,40 +867,23 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, for (auto *opValue : getOperands()) operands.push_back(mapper.lookupOrDefault(const_cast<Value *>(opValue))); - if (auto *forInst = dyn_cast<ForInst>(this)) { - auto lbMap = forInst->getLowerBoundMap(); - auto ubMap = forInst->getUpperBoundMap(); + // Otherwise, this must be a ForInst. + auto *forInst = cast<ForInst>(this); + auto lbMap = forInst->getLowerBoundMap(); + auto ubMap = forInst->getUpperBoundMap(); - auto *newFor = ForInst::create( - getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()), - lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()), - ubMap, forInst->getStep()); + auto *newFor = ForInst::create( + getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()), + lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()), ubMap, + forInst->getStep()); - // Remember the induction variable mapping. - mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); - - // Recursively clone the body of the for loop. - for (auto &subInst : *forInst->getBody()) - newFor->getBody()->push_back(subInst.clone(mapper, context)); - - return newFor; - } - - // Otherwise, we must have an If instruction. - auto *ifInst = cast<IfInst>(this); - auto *newIf = IfInst::create(getLoc(), operands, ifInst->getIntegerSet()); - - auto *resultThen = newIf->getThen(); - for (auto &childInst : *ifInst->getThen()) - resultThen->push_back(childInst.clone(mapper, context)); - - if (ifInst->hasElse()) { - auto *resultElse = newIf->createElse(); - for (auto &childInst : *ifInst->getElse()) - resultElse->push_back(childInst.clone(mapper, context)); - } + // Remember the induction variable mapping. + mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); - return newIf; + // Recursively clone the body of the for loop. + for (auto &subInst : *forInst->getBody()) + newFor->getBody()->push_back(subInst.clone(mapper, context)); + return newFor; } Instruction *Instruction::clone(MLIRContext *context) const { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 099b218892f..2ab151f8913 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -281,7 +281,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) { if (!block || &block->back() != op) return op->emitOpError("must be the last instruction in the parent block"); - // Terminators may not exist in ForInst and IfInst. + // TODO(riverriddle) Terminators may not exist with an operation region. if (block->getContainingInst()) return op->emitOpError("may only be at the top level of a function"); diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 6418b062dc1..7103eeb7389 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -66,8 +66,6 @@ MLIRContext *IROperandOwner::getContext() const { return cast<OperationInst>(this)->getContext(); case Kind::ForInst: return cast<ForInst>(this)->getContext(); - case Kind::IfInst: - return cast<IfInst>(this)->getContext(); } } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index c477ad1bbc5..e5d6aa46565 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -996,8 +996,7 @@ Attribute Parser::parseAttribute(Type type) { AffineMap map; IntegerSet set; if (parseAffineMapOrIntegerSetReference(map, set)) - return (emitError("expected affine map or integer set attribute value"), - nullptr); + return nullptr; if (map) return builder.getAffineMapAttr(map); assert(set); @@ -2209,8 +2208,6 @@ public: const char *affineStructName); ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map, bool isLower); - ParseResult parseIfInst(); - ParseResult parseElseClause(Block *elseClause); ParseResult parseInstructions(Block *block); private: @@ -2392,10 +2389,6 @@ ParseResult FunctionParser::parseBlockBody(Block *block) { if (parseForInst()) return ParseFailure; break; - case Token::kw_if: - if (parseIfInst()) - return ParseFailure; - break; } } @@ -2935,12 +2928,18 @@ public: return false; } - /// Parse a keyword followed by a type. - bool parseKeywordType(const char *keyword, Type &result) override { - if (parser.getTokenSpelling() != keyword) - return parser.emitError("expected '" + Twine(keyword) + "'"); - parser.consumeToken(); - return !(result = parser.parseType()); + /// Parse an optional keyword. + bool parseOptionalKeyword(const char *keyword) override { + // Check that the current token is a bare identifier or keyword. + if (parser.getToken().isNot(Token::bare_identifier) && + !parser.getToken().isKeyword()) + return true; + + if (parser.getTokenSpelling() == keyword) { + parser.consumeToken(); + return false; + } + return true; } /// Parse an arbitrary attribute of a given type and return it in result. This @@ -3078,6 +3077,15 @@ public: return result == nullptr; } + /// Parses a list of blocks. + bool parseBlockList() override { + SmallVector<Block *, 2> results; + if (parser.parseOperationBlockList(results)) + return true; + parsedBlockLists.emplace_back(results); + return false; + } + //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// @@ -3099,6 +3107,11 @@ public: /// Emit a diagnostic at the specified location and return true. bool emitError(llvm::SMLoc loc, const Twine &message) override { + // If we emit an error, then cleanup any parsed block lists. + for (auto &blockList : parsedBlockLists) + parser.cleanupInvalidBlocks(blockList); + parsedBlockLists.clear(); + parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message); emittedError = true; return true; @@ -3106,7 +3119,13 @@ public: bool didEmitError() const { return emittedError; } + /// Returns the block lists that were parsed. + MutableArrayRef<SmallVector<Block *, 2>> getParsedBlockLists() { + return parsedBlockLists; + } + private: + std::vector<SmallVector<Block *, 2>> parsedBlockLists; SMLoc nameLoc; StringRef opName; FunctionParser &parser; @@ -3145,8 +3164,25 @@ OperationInst *FunctionParser::parseCustomOperation() { if (opAsmParser.didEmitError()) return nullptr; + // Check that enough block lists were reserved for those that were parsed. + auto parsedBlockLists = opAsmParser.getParsedBlockLists(); + if (parsedBlockLists.size() > opState.numBlockLists) { + opAsmParser.emitError( + opLoc, + "parsed more block lists than those reserved in the operation state"); + return nullptr; + } + // Otherwise, we succeeded. Use the state it parsed as our op information. - return builder.createOperation(opState); + auto *opInst = builder.createOperation(opState); + + // Resolve any parsed block lists. + for (unsigned i = 0, e = parsedBlockLists.size(); i != e; ++i) { + auto &opBlockList = opInst->getBlockList(i).getBlocks(); + opBlockList.insert(opBlockList.end(), parsedBlockLists[i].begin(), + parsedBlockLists[i].end()); + } + return opInst; } /// For instruction. @@ -3438,69 +3474,6 @@ IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs); } -/// If instruction. -/// -/// ml-if-head ::= `if` ml-if-cond trailing-location? `{` inst* `}` -/// | ml-if-head `else` `if` ml-if-cond trailing-location? -/// `{` inst* `}` -/// ml-if-inst ::= ml-if-head -/// | ml-if-head `else` `{` inst* `}` -/// -ParseResult FunctionParser::parseIfInst() { - auto loc = getToken().getLoc(); - consumeToken(Token::kw_if); - - IntegerSet set = parseIntegerSetReference(); - if (!set) - return ParseFailure; - - SmallVector<Value *, 4> operands; - if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(), - "integer set")) - return ParseFailure; - - IfInst *ifInst = - builder.createIf(getEncodedSourceLocation(loc), operands, set); - - // Try to parse the optional trailing location. - if (parseOptionalTrailingLocation(ifInst)) - return ParseFailure; - - Block *thenClause = ifInst->getThen(); - - // When parsing of an if instruction body fails, the IR contains - // the if instruction with the portion of the body that has been - // successfully parsed. - if (parseToken(Token::l_brace, "expected '{' before instruction list") || - parseBlock(thenClause) || - parseToken(Token::r_brace, "expected '}' after instruction list")) - return ParseFailure; - - if (consumeIf(Token::kw_else)) { - auto *elseClause = ifInst->createElse(); - if (parseElseClause(elseClause)) - return ParseFailure; - } - - // Reset insertion point to the current block. - builder.setInsertionPointToEnd(ifInst->getBlock()); - - return ParseSuccess; -} - -ParseResult FunctionParser::parseElseClause(Block *elseClause) { - if (getToken().is(Token::kw_if)) { - builder.setInsertionPointToEnd(elseClause); - return parseIfInst(); - } - - if (parseToken(Token::l_brace, "expected '{' before instruction list") || - parseBlock(elseClause) || - parseToken(Token::r_brace, "expected '}' after instruction list")) - return ParseFailure; - return ParseSuccess; -} - //===----------------------------------------------------------------------===// // Top-level entity parsing. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index 40e98b25cb3..ec00f98b3f5 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -91,7 +91,6 @@ TOK_KEYWORD(attributes) TOK_KEYWORD(bf16) TOK_KEYWORD(ceildiv) TOK_KEYWORD(dense) -TOK_KEYWORD(else) TOK_KEYWORD(splat) TOK_KEYWORD(f16) TOK_KEYWORD(f32) @@ -100,7 +99,6 @@ TOK_KEYWORD(false) TOK_KEYWORD(floordiv) TOK_KEYWORD(for) TOK_KEYWORD(func) -TOK_KEYWORD(if) TOK_KEYWORD(index) TOK_KEYWORD(loc) TOK_KEYWORD(max) diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index c2e1636626d..afd18a49b79 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -188,16 +188,6 @@ void CSE::simplifyBlock(Block *bb) { simplifyBlock(cast<ForInst>(i).getBody()); break; } - case Instruction::Kind::If: { - auto &ifInst = cast<IfInst>(i); - if (auto *elseBlock = ifInst.getElse()) { - ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(elseBlock); - } - ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(ifInst.getThen()); - break; - } } } } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index cee0a08a63c..eebbbe9daa7 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -99,16 +100,16 @@ public: SmallVector<ForInst *, 4> forInsts; SmallVector<OperationInst *, 4> loadOpInsts; SmallVector<OperationInst *, 4> storeOpInsts; - bool hasIfInst = false; + bool hasNonForRegion = false; void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } - void visitIfInst(IfInst *ifInst) { hasIfInst = true; } - void visitOperationInst(OperationInst *opInst) { - if (opInst->isa<LoadOp>()) + if (opInst->getNumBlockLists() != 0) + hasNonForRegion = true; + else if (opInst->isa<LoadOp>()) loadOpInsts.push_back(opInst); - if (opInst->isa<StoreOp>()) + else if (opInst->isa<StoreOp>()) storeOpInsts.push_back(opInst); } }; @@ -410,8 +411,8 @@ bool MemRefDependenceGraph::init(Function *f) { // all loads and store accesses it contains. LoopNestStateCollector collector; collector.walkForInst(forInst); - // Return false if IfInsts are found (not currently supported). - if (collector.hasIfInst) + // Return false if a non 'for' region was found (not currently supported). + if (collector.hasNonForRegion) return false; Node node(id++, &inst); for (auto *opInst : collector.loadOpInsts) { @@ -434,19 +435,18 @@ bool MemRefDependenceGraph::init(Function *f) { auto *memref = opInst->cast<LoadOp>()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } - if (auto storeOp = opInst->dyn_cast<StoreOp>()) { + } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) { // Create graph node for top-level store op. Node node(id++, &inst); node.stores.push_back(opInst); auto *memref = opInst->cast<StoreOp>()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); + } else if (opInst->getNumBlockLists() != 0) { + // Return false if another region is found (not currently supported). + return false; } } - // Return false if IfInsts are found (not currently supported). - if (isa<IfInst>(&inst)) - return false; } // Walk memref access lists and add graph edges between dependent nodes. diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 39ef758833b..6d63e4afd2d 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -119,15 +119,6 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return true; } - bool walkIfInstPostOrder(IfInst *ifInst) { - bool hasInnerLoops = - walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end()); - if (ifInst->hasElse()) - hasInnerLoops |= - walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end()); - return hasInnerLoops; - } - bool walkOpInstPostOrder(OperationInst *opInst) { for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index ab37ff63bad..f770684f519 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -246,7 +247,7 @@ public: PassResult runOnFunction(Function *function) override; bool lowerForInst(ForInst *forInst); - bool lowerIfInst(IfInst *ifInst); + bool lowerAffineIf(AffineIfOp *ifOp); bool lowerAffineApply(AffineApplyOp *op); static char passID; @@ -409,7 +410,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { // enabling easy nesting of "if" instructions and if-then-else-if chains. // // +--------------------------------+ -// | <code before the IfInst> | +// | <code before the AffineIfOp> | // | %zero = constant 0 : index | // | %v = affine_apply #expr1(%ops) | // | %c = cmpi "sge" %v, %zero | @@ -453,10 +454,11 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { // v v // +--------------------------------+ // | continue: | -// | <code after the IfInst> | +// | <code after the AffineIfOp> | // +--------------------------------+ // -bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { +bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) { + auto *ifInst = ifOp->getInstruction(); auto loc = ifInst->getLoc(); // Start by splitting the block containing the 'if' into two parts. The part @@ -466,22 +468,38 @@ bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { auto *continueBlock = condBlock->splitBlock(ifInst); // Create a block for the 'then' code, inserting it between the cond and - // continue blocks. Move the instructions over from the IfInst and add a + // continue blocks. Move the instructions over from the AffineIfOp and add a // branch to the continuation point. Block *thenBlock = new Block(); thenBlock->insertBefore(continueBlock); - auto *oldThen = ifInst->getThen(); - thenBlock->getInstructions().splice(thenBlock->begin(), - oldThen->getInstructions(), - oldThen->begin(), oldThen->end()); + // If the 'then' block is not empty, then splice the instructions. + auto &oldThenBlocks = ifOp->getThenBlocks(); + if (!oldThenBlocks.empty()) { + // We currently only handle one 'then' block. + if (std::next(oldThenBlocks.begin()) != oldThenBlocks.end()) + return true; + + Block *oldThen = &oldThenBlocks.front(); + + thenBlock->getInstructions().splice(thenBlock->begin(), + oldThen->getInstructions(), + oldThen->begin(), oldThen->end()); + } + FuncBuilder builder(thenBlock); builder.create<BranchOp>(loc, continueBlock); // Handle the 'else' block the same way, but we skip it if we have no else // code. Block *elseBlock = continueBlock; - if (auto *oldElse = ifInst->getElse()) { + auto &oldElseBlocks = ifOp->getElseBlocks(); + if (!oldElseBlocks.empty()) { + // We currently only handle one 'else' block. + if (std::next(oldElseBlocks.begin()) != oldElseBlocks.end()) + return true; + + auto *oldElse = &oldElseBlocks.front(); elseBlock = new Block(); elseBlock->insertBefore(continueBlock); @@ -493,7 +511,7 @@ bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { } // Ok, now we just have to handle the condition logic. - auto integerSet = ifInst->getCondition().getIntegerSet(); + auto integerSet = ifOp->getIntegerSet(); // Implement short-circuit logic. For each affine expression in the 'if' // condition, convert it into an affine map and call `affine_apply` to obtain @@ -593,29 +611,30 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) { PassResult LowerAffinePass::runOnFunction(Function *function) { SmallVector<Instruction *, 8> instsToRewrite; - // Collect all the If and For instructions as well as AffineApplyOps. We do - // this as a prepass to avoid invalidating the walker with our rewrite. + // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. + // We do this as a prepass to avoid invalidating the walker with our rewrite. function->walkInsts([&](Instruction *inst) { - if (isa<IfInst>(inst) || isa<ForInst>(inst)) + if (isa<ForInst>(inst)) instsToRewrite.push_back(inst); auto op = dyn_cast<OperationInst>(inst); - if (op && op->isa<AffineApplyOp>()) + if (op && (op->isa<AffineApplyOp>() || op->isa<AffineIfOp>())) instsToRewrite.push_back(inst); }); // Rewrite all of the ifs and fors. We walked the instructions in preorder, // so we know that we will rewrite them in the same order. for (auto *inst : instsToRewrite) - if (auto *ifInst = dyn_cast<IfInst>(inst)) { - if (lowerIfInst(ifInst)) - return failure(); - } else if (auto *forInst = dyn_cast<ForInst>(inst)) { + if (auto *forInst = dyn_cast<ForInst>(inst)) { if (lowerForInst(forInst)) return failure(); } else { auto op = cast<OperationInst>(inst); - if (lowerAffineApply(op->cast<AffineApplyOp>())) + if (auto ifOp = op->dyn_cast<AffineIfOp>()) { + if (lowerAffineIf(ifOp)) + return failure(); + } else if (lowerAffineApply(op->cast<AffineApplyOp>())) { return failure(); + } } return success(); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 09d961f85cd..2744b1d624c 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -559,9 +560,6 @@ static bool instantiateMaterialization(Instruction *inst, if (isa<ForInst>(inst)) return inst->emitError("NYI path ForInst"); - if (isa<IfInst>(inst)) - return inst->emitError("NYI path IfInst"); - // Create a builder here for unroll-and-jam effects. FuncBuilder b(inst); auto *opInst = cast<OperationInst>(inst); @@ -570,6 +568,9 @@ static bool instantiateMaterialization(Instruction *inst, if (opInst->isa<AffineApplyOp>()) { return false; } + if (opInst->getNumBlockLists() != 0) + return inst->emitError("NYI path Op with region"); + if (auto write = opInst->dyn_cast<VectorTransferWriteOp>()) { auto *clone = instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index bd39e47786a..ba59123c700 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -28,7 +28,6 @@ #define DEBUG_TYPE "simplify-affine-structure" using namespace mlir; -using llvm::report_fatal_error; namespace { @@ -42,9 +41,6 @@ struct SimplifyAffineStructures : public FunctionPass { PassResult runOnFunction(Function *f) override; - void visitIfInst(IfInst *ifInst); - void visitOperationInst(OperationInst *opInst); - static char passID; }; @@ -66,28 +62,19 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { return set; } -void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) { - auto set = ifInst->getCondition().getIntegerSet(); - ifInst->setIntegerSet(simplifyIntegerSet(set)); -} - -void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) { - for (auto attr : opInst->getAttrs()) { - if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) { - MutableAffineMap mMap(mapAttr.getValue()); - mMap.simplify(); - auto map = mMap.getAffineMap(); - opInst->setAttr(attr.first, AffineMapAttr::get(map)); - } - } -} - PassResult SimplifyAffineStructures::runOnFunction(Function *f) { - f->walkInsts([&](Instruction *inst) { - if (auto *opInst = dyn_cast<OperationInst>(inst)) - visitOperationInst(opInst); - if (auto *ifInst = dyn_cast<IfInst>(inst)) - visitIfInst(ifInst); + f->walkOps([&](OperationInst *opInst) { + for (auto attr : opInst->getAttrs()) { + if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) { + MutableAffineMap mMap(mapAttr.getValue()); + mMap.simplify(); + auto map = mMap.getAffineMap(); + opInst->setAttr(attr.first, AffineMapAttr::get(map)); + } else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>()) { + auto simplified = simplifyIntegerSet(setAttr.getValue()); + opInst->setAttr(attr.first, IntegerSetAttr::get(simplified)); + } + } }); return success(); |

