diff options
36 files changed, 495 insertions, 542 deletions
diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h new file mode 100644 index 00000000000..d511f628c3c --- /dev/null +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -0,0 +1,91 @@ +//===- AffineOps.h - MLIR Affine Operations -------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines convenience types for working with Affine operations +// in the MLIR instruction set. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_AFFINEOPS_AFFINEOPS_H +#define MLIR_AFFINEOPS_AFFINEOPS_H + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { + +class AffineOpsDialect : public Dialect { +public: + AffineOpsDialect(MLIRContext *context); +}; + +/// The "if" operation represents an if–then–else construct for conditionally +/// executing two regions of code. The operands to an if operation are an +/// IntegerSet condition and a set of symbol/dimension operands to the +/// condition set. The operation produces no results. For example: +/// +/// if #set(%i) { +/// ... +/// } else { +/// ... +/// } +/// +/// The 'else' blocks to the if operation are optional, and may be omitted. For +/// example: +/// +/// if #set(%i) { +/// ... +/// } +/// +class AffineIfOp + : public Op<AffineIfOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> { +public: + // Hooks to customize behavior of this op. + static void build(Builder *builder, OperationState *result, + IntegerSet condition, ArrayRef<Value *> conditionOperands); + + static StringRef getOperationName() { return "if"; } + static StringRef getConditionAttrName() { return "condition"; } + + IntegerSet getIntegerSet() const; + void setIntegerSet(IntegerSet newSet); + + /// Returns the list of 'then' blocks. + BlockList &getThenBlocks(); + const BlockList &getThenBlocks() const { + return const_cast<AffineIfOp *>(this)->getThenBlocks(); + } + + /// Returns the list of 'else' blocks. + BlockList &getElseBlocks(); + const BlockList &getElseBlocks() const { + return const_cast<AffineIfOp *>(this)->getElseBlocks(); + } + + bool verify() const; + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + +private: + friend class OperationInst; + explicit AffineIfOp(const OperationInst *state) : Op(state) {} +}; + +} // end namespace mlir + +#endif diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index c205d55488e..161bb217a10 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -128,7 +128,6 @@ private: void matchOne(Instruction *elem); void visitForInst(ForInst *forInst) { matchOne(forInst); } - void visitIfInst(IfInst *ifInst) { matchOne(ifInst); } void visitOperationInst(OperationInst *opInst) { matchOne(opInst); } /// POD paylod. diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index 1b14d925d32..bc9563f847a 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -26,7 +26,6 @@ #include "llvm/ADT/PointerUnion.h" namespace mlir { -class IfInst; class BlockList; class BlockAndValueMapping; @@ -62,7 +61,7 @@ public: } /// Returns the function that this block is part of, even if the block is - /// nested under an IfInst or ForInst. + /// nested under an OperationInst or ForInst. Function *getFunction(); const Function *getFunction() const { return const_cast<Block *>(this)->getFunction(); @@ -325,7 +324,7 @@ private: namespace mlir { /// This class contains a list of basic blocks and has a notion of the object it -/// is part of - a Function or IfInst or ForInst. +/// is part of - a Function or OperationInst or ForInst. class BlockList { public: explicit BlockList(Function *container); @@ -365,15 +364,16 @@ public: return &BlockList::blocks; } - /// A BlockList is part of a Function or and IfInst/ForInst. If it is - /// part of an IfInst/ForInst, then return it, otherwise return null. + /// A BlockList is part of a function or an operation region. If it is + /// part of an operation region, then return the operation, otherwise return + /// null. Instruction *getContainingInst(); const Instruction *getContainingInst() const { return const_cast<BlockList *>(this)->getContainingInst(); } - /// A BlockList is part of a Function or and IfInst/ForInst. If it is - /// part of a Function, then return it, otherwise return null. + /// A BlockList is part of a function or an operation region. If it is part + /// of a Function, then return it, otherwise return null. Function *getContainingFunction(); const Function *getContainingFunction() const { return const_cast<BlockList *>(this)->getContainingFunction(); diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 156bd02bb52..3271c12afde 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -286,10 +286,6 @@ public: // Default step is 1. ForInst *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1); - /// Creates if instruction. - IfInst *createIf(Location location, ArrayRef<Value *> operands, - IntegerSet set); - private: Function *function; Block *block = nullptr; diff --git a/mlir/include/mlir/IR/InstVisitor.h b/mlir/include/mlir/IR/InstVisitor.h index b6a759e76f5..78810da909d 100644 --- a/mlir/include/mlir/IR/InstVisitor.h +++ b/mlir/include/mlir/IR/InstVisitor.h @@ -44,7 +44,7 @@ // lc.walk(function); // numLoops = lc.numLoops; // -// There are 'visit' methods for OperationInst, ForInst, IfInst, and +// There are 'visit' methods for OperationInst, ForInst, and // Function, which recursively process all contained instructions. // // Note that if you don't implement visitXXX for some instruction type, @@ -85,8 +85,6 @@ public: switch (s->getKind()) { case Instruction::Kind::For: return static_cast<SubClass *>(this)->visitForInst(cast<ForInst>(s)); - case Instruction::Kind::If: - return static_cast<SubClass *>(this)->visitIfInst(cast<IfInst>(s)); case Instruction::Kind::OperationInst: return static_cast<SubClass *>(this)->visitOperationInst( cast<OperationInst>(s)); @@ -104,7 +102,6 @@ public: // When visiting a for inst, if inst, or an operation inst directly, these // methods get called to indicate when transitioning into a new unit. void visitForInst(ForInst *forInst) {} - void visitIfInst(IfInst *ifInst) {} void visitOperationInst(OperationInst *opInst) {} }; @@ -166,23 +163,6 @@ public: static_cast<SubClass *>(this)->visitForInst(forInst); } - void walkIfInst(IfInst *ifInst) { - static_cast<SubClass *>(this)->visitIfInst(ifInst); - static_cast<SubClass *>(this)->walk(ifInst->getThen()->begin(), - ifInst->getThen()->end()); - if (auto *elseBlock = ifInst->getElse()) - static_cast<SubClass *>(this)->walk(elseBlock->begin(), elseBlock->end()); - } - - void walkIfInstPostOrder(IfInst *ifInst) { - static_cast<SubClass *>(this)->walkPostOrder(ifInst->getThen()->begin(), - ifInst->getThen()->end()); - if (auto *elseBlock = ifInst->getElse()) - static_cast<SubClass *>(this)->walkPostOrder(elseBlock->begin(), - elseBlock->end()); - static_cast<SubClass *>(this)->visitIfInst(ifInst); - } - // Function to walk a instruction. RetTy walk(Instruction *s) { static_assert(std::is_base_of<InstWalker, SubClass>::value, @@ -193,8 +173,6 @@ public: switch (s->getKind()) { case Instruction::Kind::For: return static_cast<SubClass *>(this)->walkForInst(cast<ForInst>(s)); - case Instruction::Kind::If: - return static_cast<SubClass *>(this)->walkIfInst(cast<IfInst>(s)); case Instruction::Kind::OperationInst: return static_cast<SubClass *>(this)->walkOpInst(cast<OperationInst>(s)); } @@ -210,9 +188,6 @@ public: case Instruction::Kind::For: return static_cast<SubClass *>(this)->walkForInstPostOrder( cast<ForInst>(s)); - case Instruction::Kind::If: - return static_cast<SubClass *>(this)->walkIfInstPostOrder( - cast<IfInst>(s)); case Instruction::Kind::OperationInst: return static_cast<SubClass *>(this)->walkOpInstPostOrder( cast<OperationInst>(s)); @@ -231,7 +206,6 @@ public: // processing their descendants in some way. When using RetTy, all of these // need to be overridden. void visitForInst(ForInst *forInst) {} - void visitIfInst(IfInst *ifInst) {} void visitOperationInst(OperationInst *opInst) {} void visitInstruction(Instruction *inst) {} }; diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h index 6a296b7348e..3dc1e76dd20 100644 --- a/mlir/include/mlir/IR/Instruction.h +++ b/mlir/include/mlir/IR/Instruction.h @@ -75,7 +75,6 @@ public: enum class Kind { OperationInst = (int)IROperandOwner::Kind::OperationInst, For = (int)IROperandOwner::Kind::ForInst, - If = (int)IROperandOwner::Kind::IfInst, }; Kind getKind() const { return (Kind)IROperandOwner::getKind(); } diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h index 71d832b8b90..fb6b1b97ca0 100644 --- a/mlir/include/mlir/IR/Instructions.h +++ b/mlir/include/mlir/IR/Instructions.h @@ -794,130 +794,6 @@ private: friend class ForInst; }; - -/// If instruction restricts execution to a subset of the loop iteration space. -class IfInst : public Instruction { -public: - static IfInst *create(Location location, ArrayRef<Value *> operands, - IntegerSet set); - ~IfInst(); - - //===--------------------------------------------------------------------===// - // Then, else, condition. - //===--------------------------------------------------------------------===// - - Block *getThen() { return &thenClause.front(); } - const Block *getThen() const { return &thenClause.front(); } - Block *getElse() { return elseClause ? &elseClause->front() : nullptr; } - const Block *getElse() const { - return elseClause ? &elseClause->front() : nullptr; - } - bool hasElse() const { return elseClause != nullptr; } - - Block *createElse() { - assert(elseClause == nullptr && "already has an else clause!"); - elseClause = new BlockList(this); - elseClause->push_back(new Block()); - return &elseClause->front(); - } - - const AffineCondition getCondition() const; - - IntegerSet getIntegerSet() const { return set; } - void setIntegerSet(IntegerSet newSet) { - assert(newSet.getNumOperands() == operands.size()); - set = newSet; - } - - //===--------------------------------------------------------------------===// - // Operands - //===--------------------------------------------------------------------===// - - /// Operand iterators. - using operand_iterator = OperandIterator<IfInst, Value>; - using const_operand_iterator = OperandIterator<const IfInst, const Value>; - - /// Operand iterator range. - using operand_range = llvm::iterator_range<operand_iterator>; - using const_operand_range = llvm::iterator_range<const_operand_iterator>; - - unsigned getNumOperands() const { return operands.size(); } - - Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } - const Value *getOperand(unsigned idx) const { - return getInstOperand(idx).get(); - } - void setOperand(unsigned idx, Value *value) { - getInstOperand(idx).set(value); - } - - operand_iterator operand_begin() { return operand_iterator(this, 0); } - operand_iterator operand_end() { - return operand_iterator(this, getNumOperands()); - } - - const_operand_iterator operand_begin() const { - return const_operand_iterator(this, 0); - } - const_operand_iterator operand_end() const { - return const_operand_iterator(this, getNumOperands()); - } - - ArrayRef<InstOperand> getInstOperands() const { return operands; } - MutableArrayRef<InstOperand> getInstOperands() { return operands; } - InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } - const InstOperand &getInstOperand(unsigned idx) const { - return getInstOperands()[idx]; - } - - //===--------------------------------------------------------------------===// - // Other - //===--------------------------------------------------------------------===// - - MLIRContext *getContext() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const IROperandOwner *ptr) { - return ptr->getKind() == IROperandOwner::Kind::IfInst; - } - -private: - // it is always present. - BlockList thenClause; - // 'else' clause of the if instruction. 'nullptr' if there is no else clause. - BlockList *elseClause; - - // The integer set capturing the conditional guard. - IntegerSet set; - - // Condition operands. - std::vector<InstOperand> operands; - - explicit IfInst(Location location, unsigned numOperands, IntegerSet set); -}; - -/// AffineCondition represents a condition of the 'if' instruction. -/// Its life span should not exceed that of the objects it refers to. -/// AffineCondition does not provide its own methods for iterating over -/// the operands since the iterators of the if instruction accomplish -/// the same purpose. -/// -/// AffineCondition is trivially copyable, so it should be passed by value. -class AffineCondition { -public: - const IfInst *getIfInst() const { return &inst; } - IntegerSet getIntegerSet() const { return set; } - -private: - // 'if' instruction that contains this affine condition. - const IfInst &inst; - // Integer set for this affine condition. - IntegerSet set; - - AffineCondition(const IfInst &inst, IntegerSet set) : inst(inst), set(set) {} - - friend class IfInst; -}; } // end namespace mlir #endif // MLIR_IR_INSTRUCTIONS_H diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 1e319db3571..d3a5d35427f 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -89,6 +89,9 @@ public: /// Print the entire operation with the default generic assembly form. virtual void printGenericOp(const OperationInst *op) = 0; + /// Prints a block list. + virtual void printBlockList(const BlockList &blocks) = 0; + private: OpAsmPrinter(const OpAsmPrinter &) = delete; void operator=(const OpAsmPrinter &) = delete; @@ -195,7 +198,19 @@ public: virtual bool parseColonTypeList(SmallVectorImpl<Type> &result) = 0; /// Parse a keyword followed by a type. - virtual bool parseKeywordType(const char *keyword, Type &result) = 0; + bool parseKeywordType(const char *keyword, Type &result) { + return parseKeyword(keyword) || parseType(result); + } + + /// Parse a keyword. + bool parseKeyword(const char *keyword) { + if (parseOptionalKeyword(keyword)) + return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'"); + return false; + } + + /// If a keyword is present, then parse it. + virtual bool parseOptionalKeyword(const char *keyword) = 0; /// Add the specified type to the end of the specified type list and return /// false. This is a helper designed to allow parse methods to be simple and @@ -296,6 +311,10 @@ public: int requiredOperandCount = -1, Delimiter delimiter = Delimiter::None) = 0; + /// Parses a block list. Any parsed blocks are filled in to the + /// operation's block lists after the operation is created. + virtual bool parseBlockList() = 0; + //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index 053d3520103..80cd21362ce 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -81,10 +81,9 @@ public: enum class Kind { OperationInst, ForInst, - IfInst, /// These enums define ranges used for classof implementations. - INST_LAST = IfInst, + INST_LAST = ForInst, }; Kind getKind() const { return locationAndKind.getInt(); } diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h index 978fa45ab23..00c6577240c 100644 --- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h +++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h @@ -93,7 +93,7 @@ using OwningMLLoweringPatternList = /// next _original_ operation is considered. /// In other words, for each operation, the pass applies the first matching /// rewriter in the list and advances to the (lexically) next operation. -/// Non-operation instructions (ForInst and IfInst) are ignored. +/// Non-operation instructions (ForInst) are ignored. /// This is similar to greedy worklist-based pattern rewriter, except that this /// operates on ML functions using an ML builder and does not maintain the work /// list. Note that, as of the time of writing, worklist-based rewriter did not diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp new file mode 100644 index 00000000000..5b29467fc44 --- /dev/null +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -0,0 +1,151 @@ +//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/AffineOps/AffineOps.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpImplementation.h" +using namespace mlir; + +//===----------------------------------------------------------------------===// +// AffineOpsDialect +//===----------------------------------------------------------------------===// + +AffineOpsDialect::AffineOpsDialect(MLIRContext *context) + : Dialect(/*namePrefix=*/"", context) { + addOperations<AffineIfOp>(); +} + +//===----------------------------------------------------------------------===// +// AffineIfOp +//===----------------------------------------------------------------------===// + +void AffineIfOp::build(Builder *builder, OperationState *result, + IntegerSet condition, + ArrayRef<Value *> conditionOperands) { + result->addAttribute(getConditionAttrName(), IntegerSetAttr::get(condition)); + result->addOperands(conditionOperands); + + // Reserve 2 block lists, one for the 'then' and one for the 'else' regions. + result->reserveBlockLists(2); +} + +bool AffineIfOp::verify() const { + // Verify that we have a condition attribute. + auto conditionAttr = getAttrOfType<IntegerSetAttr>(getConditionAttrName()); + if (!conditionAttr) + return emitOpError("requires an integer set attribute named 'condition'"); + + // Verify that the operands are valid dimension/symbols. + IntegerSet condition = conditionAttr.getValue(); + for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { + const Value *operand = getOperand(i); + if (i < condition.getNumDims() && !operand->isValidDim()) + return emitOpError("operand cannot be used as a dimension id"); + if (i >= condition.getNumDims() && !operand->isValidSymbol()) + return emitOpError("operand cannot be used as a symbol"); + } + + // Verify that the entry of each child blocklist does not have arguments. + for (const auto &blockList : getInstruction()->getBlockLists()) { + if (blockList.empty()) + continue; + + // TODO(riverriddle) We currently do not allow multiple blocks in child + // block lists. + if (std::next(blockList.begin()) != blockList.end()) + return emitOpError( + "expects only one block per 'if' or 'else' block list"); + if (blockList.front().getTerminator()) + return emitOpError("expects region block to not have a terminator"); + + for (const auto &b : blockList) + if (b.getNumArguments() != 0) + return emitOpError( + "requires that child entry blocks have no arguments"); + } + return false; +} + +bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) { + // Parse the condition attribute set. + IntegerSetAttr conditionAttr; + unsigned numDims; + if (parser->parseAttribute(conditionAttr, getConditionAttrName().data(), + result->attributes) || + parseDimAndSymbolList(parser, result->operands, numDims)) + return true; + + // Verify the condition operands. + auto set = conditionAttr.getValue(); + if (set.getNumDims() != numDims) + return parser->emitError( + parser->getNameLoc(), + "dim operand count and integer set dim count must match"); + if (numDims + set.getNumSymbols() != result->operands.size()) + return parser->emitError( + parser->getNameLoc(), + "symbol operand count and integer set symbol count must match"); + + // Parse the 'then' block list. + if (parser->parseBlockList()) + return true; + + // If we find an 'else' keyword then parse the else block list. + if (!parser->parseOptionalKeyword("else")) { + if (parser->parseBlockList()) + return true; + } + + // Reserve 2 block lists, one for the 'then' and one for the 'else' regions. + result->reserveBlockLists(2); + return false; +} + +void AffineIfOp::print(OpAsmPrinter *p) const { + auto conditionAttr = getAttrOfType<IntegerSetAttr>(getConditionAttrName()); + *p << "if " << conditionAttr; + printDimAndSymbolList(operand_begin(), operand_end(), + conditionAttr.getValue().getNumDims(), p); + p->printBlockList(getInstruction()->getBlockList(0)); + + // Print the 'else' block list if it has any blocks. + const auto &elseBlockList = getInstruction()->getBlockList(1); + if (!elseBlockList.empty()) { + *p << " else"; + p->printBlockList(elseBlockList); + } +} + +IntegerSet AffineIfOp::getIntegerSet() const { + return getAttrOfType<IntegerSetAttr>(getConditionAttrName()).getValue(); +} +void AffineIfOp::setIntegerSet(IntegerSet newSet) { + setAttr( + Identifier::get(getConditionAttrName(), getInstruction()->getContext()), + IntegerSetAttr::get(newSet)); +} + +/// Returns the list of 'then' blocks. +BlockList &AffineIfOp::getThenBlocks() { + return getInstruction()->getBlockList(0); +} + +/// Returns the list of 'else' blocks. +BlockList &AffineIfOp::getElseBlocks() { + return getInstruction()->getBlockList(1); +} diff --git a/mlir/lib/AffineOps/DialectRegistration.cpp b/mlir/lib/AffineOps/DialectRegistration.cpp new file mode 100644 index 00000000000..0afb32c1bd6 --- /dev/null +++ b/mlir/lib/AffineOps/DialectRegistration.cpp @@ -0,0 +1,22 @@ +//===- DialectRegistration.cpp - Register Affine Op dialect ---------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/AffineOps/AffineOps.h" +using namespace mlir; + +// Static initialization for Affine op dialect registration. +static DialectRegistration<AffineOpsDialect> StandardOps; diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 219f356807a..07c903a6613 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -21,6 +21,7 @@ #include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/NestedMatcher.h" @@ -246,6 +247,16 @@ static bool isVectorizableLoopWithCond(const ForInst &loop, return false; } + // No vectorization across unknown regions. + auto regions = matcher::Op([](const Instruction &inst) -> bool { + auto &opInst = cast<OperationInst>(inst); + return opInst.getNumBlockLists() != 0 && !opInst.isa<AffineIfOp>(); + }); + auto regionsMatched = regions.match(forInst); + if (!regionsMatched.empty()) { + return false; + } + auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite); auto vectorTransfersMatched = vectorTransfers.match(forInst); if (!vectorTransfersMatched.empty()) { diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 4f32e9b22f4..491a9bef1b9 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -16,6 +16,7 @@ // ============================================================================= #include "mlir/Analysis/NestedMatcher.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/ArrayRef.h" @@ -186,6 +187,11 @@ FilterFunctionType NestedPattern::getFilterFunction() { return storage->filter; } +static bool isAffineIfOp(const Instruction &inst) { + return isa<OperationInst>(inst) && + cast<OperationInst>(inst).isa<AffineIfOp>(); +} + namespace mlir { namespace matcher { @@ -194,16 +200,22 @@ NestedPattern Op(FilterFunctionType filter) { } NestedPattern If(NestedPattern child) { - return NestedPattern(Instruction::Kind::If, child, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, child, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(Instruction::Kind::If, child, filter); + return NestedPattern(Instruction::Kind::OperationInst, child, + [filter](const Instruction &inst) { + return isAffineIfOp(inst) && filter(inst); + }); } NestedPattern If(ArrayRef<NestedPattern> nested) { - return NestedPattern(Instruction::Kind::If, nested, defaultFilterFunction); + return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) { - return NestedPattern(Instruction::Kind::If, nested, filter); + return NestedPattern(Instruction::Kind::OperationInst, nested, + [filter](const Instruction &inst) { + return isAffineIfOp(inst) && filter(inst); + }); } NestedPattern For(NestedPattern child) { diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 939a2ede618..0e77d4d9084 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -22,6 +22,7 @@ #include "mlir/Analysis/Utils.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/Builders.h" @@ -43,7 +44,7 @@ void mlir::getLoopIVs(const Instruction &inst, // Traverse up the hierarchy collecing all 'for' instruction while skipping // over 'if' instructions. while (currInst && ((currForInst = dyn_cast<ForInst>(currInst)) || - isa<IfInst>(currInst))) { + cast<OperationInst>(currInst)->isa<AffineIfOp>())) { if (currForInst) loops->push_back(currForInst); currInst = currInst->getParentInst(); @@ -359,21 +360,12 @@ static Instruction *getInstAtPosition(ArrayRef<unsigned> positions, if (auto *childForInst = dyn_cast<ForInst>(&inst)) return getInstAtPosition(positions, level + 1, childForInst->getBody()); - if (auto *ifInst = dyn_cast<IfInst>(&inst)) { - auto *ret = getInstAtPosition(positions, level + 1, ifInst->getThen()); - if (ret != nullptr) - return ret; - if (auto *elseClause = ifInst->getElse()) - return getInstAtPosition(positions, level + 1, elseClause); - } - if (auto *opInst = dyn_cast<OperationInst>(&inst)) { - for (auto &blockList : opInst->getBlockLists()) { - for (auto &b : blockList) - if (auto *ret = getInstAtPosition(positions, level + 1, &b)) - return ret; - } - return nullptr; + for (auto &blockList : cast<OperationInst>(&inst)->getBlockLists()) { + for (auto &b : blockList) + if (auto *ret = getInstAtPosition(positions, level + 1, &b)) + return ret; } + return nullptr; } return nullptr; } diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 383a4878c35..474eeb2a28e 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -73,7 +73,6 @@ public: bool verifyBlock(const Block &block, bool isTopLevel); bool verifyOperation(const OperationInst &op); bool verifyForInst(const ForInst &forInst); - bool verifyIfInst(const IfInst &ifInst); bool verifyDominance(const Block &block); bool verifyInstDominance(const Instruction &inst); @@ -180,10 +179,6 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) { if (verifyForInst(cast<ForInst>(inst))) return true; break; - case Instruction::Kind::If: - if (verifyIfInst(cast<IfInst>(inst))) - return true; - break; } } @@ -250,18 +245,6 @@ bool FuncVerifier::verifyForInst(const ForInst &forInst) { return verifyBlock(*forInst.getBody(), /*isTopLevel=*/false); } -bool FuncVerifier::verifyIfInst(const IfInst &ifInst) { - // TODO: check that if conditions are properly formed. - if (verifyBlock(*ifInst.getThen(), /*isTopLevel*/ false)) - return true; - - if (auto *elseClause = ifInst.getElse()) - if (verifyBlock(*elseClause, /*isTopLevel*/ false)) - return true; - - return false; -} - bool FuncVerifier::verifyDominance(const Block &block) { for (auto &inst : block) { // Check that all operands on the instruction are ok. @@ -283,14 +266,6 @@ bool FuncVerifier::verifyDominance(const Block &block) { if (verifyDominance(*cast<ForInst>(inst).getBody())) return true; break; - case Instruction::Kind::If: - auto &ifInst = cast<IfInst>(inst); - if (verifyDominance(*ifInst.getThen())) - return true; - if (auto *elseClause = ifInst.getElse()) - if (verifyDominance(*elseClause)) - return true; - break; } } return false; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 21bc3b824b1..cb4c1f0edce 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -145,7 +145,6 @@ private: // Visit functions. void visitInstruction(const Instruction *inst); void visitForInst(const ForInst *forInst); - void visitIfInst(const IfInst *ifInst); void visitOperationInst(const OperationInst *opInst); void visitType(Type type); void visitAttribute(Attribute attr); @@ -197,10 +196,6 @@ void ModuleState::visitAttribute(Attribute attr) { } } -void ModuleState::visitIfInst(const IfInst *ifInst) { - recordIntegerSetReference(ifInst->getIntegerSet()); -} - void ModuleState::visitForInst(const ForInst *forInst) { AffineMap lbMap = forInst->getLowerBoundMap(); if (!hasCustomForm(lbMap)) @@ -225,8 +220,6 @@ void ModuleState::visitOperationInst(const OperationInst *op) { void ModuleState::visitInstruction(const Instruction *inst) { switch (inst->getKind()) { - case Instruction::Kind::If: - return visitIfInst(cast<IfInst>(inst)); case Instruction::Kind::For: return visitForInst(cast<ForInst>(inst)); case Instruction::Kind::OperationInst: @@ -1077,7 +1070,6 @@ public: void print(const Instruction *inst); void print(const OperationInst *inst); void print(const ForInst *inst); - void print(const IfInst *inst); void print(const Block *block, bool printBlockArgs = true); void printOperation(const OperationInst *op); @@ -1125,6 +1117,9 @@ public: unsigned index) override; /// Print a block list. + void printBlockList(const BlockList &blocks) override { + printBlockList(blocks, /*printEntryBlockArgs=*/true); + } void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) { os << " {\n"; if (!blocks.empty()) { @@ -1214,12 +1209,6 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { // Recursively number the stuff in the body. numberValuesInBlock(*cast<ForInst>(&inst)->getBody()); break; - case Instruction::Kind::If: { - auto *ifInst = cast<IfInst>(&inst); - numberValuesInBlock(*ifInst->getThen()); - if (auto *elseBlock = ifInst->getElse()) - numberValuesInBlock(*elseBlock); - } } } } @@ -1360,8 +1349,7 @@ void FunctionPrinter::printFunctionSignature() { } void FunctionPrinter::print(const Block *block, bool printBlockArgs) { - // Print the block label and argument list, unless this is the first block of - // the function, or the first block of an IfInst/ForInst with no arguments. + // Print the block label and argument list if requested. if (printBlockArgs) { os.indent(currentIndent); printBlockName(block); @@ -1418,8 +1406,6 @@ void FunctionPrinter::print(const Instruction *inst) { return print(cast<OperationInst>(inst)); case Instruction::Kind::For: return print(cast<ForInst>(inst)); - case Instruction::Kind::If: - return print(cast<IfInst>(inst)); } } @@ -1447,22 +1433,6 @@ void FunctionPrinter::print(const ForInst *inst) { os.indent(currentIndent) << "}"; } -void FunctionPrinter::print(const IfInst *inst) { - os.indent(currentIndent) << "if "; - IntegerSet set = inst->getIntegerSet(); - printIntegerSetReference(set); - printDimAndSymbolList(inst->getInstOperands(), set.getNumDims()); - printTrailingLocation(inst->getLoc()); - os << " {\n"; - print(inst->getThen(), /*printBlockArgs=*/false); - os.indent(currentIndent) << "}"; - if (inst->hasElse()) { - os << " else {\n"; - print(inst->getElse(), /*printBlockArgs=*/false); - os.indent(currentIndent) << "}"; - } -} - void FunctionPrinter::printValueID(const Value *value, bool printResultNo) const { int resultNo = -1; diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 4471ff25e94..e174fdc1d00 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -327,10 +327,3 @@ ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub, auto ubMap = AffineMap::getConstantMap(ub, context); return createFor(location, {}, lbMap, {}, ubMap, step); } - -IfInst *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands, - IntegerSet set) { - auto *inst = IfInst::create(location, operands, set); - block->getInstructions().insert(insertPoint, inst); - return inst; -} diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index 6d74ed14257..0ccab2305ec 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -73,9 +73,6 @@ void Instruction::destroy() { case Kind::For: delete cast<ForInst>(this); break; - case Kind::If: - delete cast<IfInst>(this); - break; } } @@ -141,8 +138,6 @@ unsigned Instruction::getNumOperands() const { return cast<OperationInst>(this)->getNumOperands(); case Kind::For: return cast<ForInst>(this)->getNumOperands(); - case Kind::If: - return cast<IfInst>(this)->getNumOperands(); } } @@ -152,8 +147,6 @@ MutableArrayRef<InstOperand> Instruction::getInstOperands() { return cast<OperationInst>(this)->getInstOperands(); case Kind::For: return cast<ForInst>(this)->getInstOperands(); - case Kind::If: - return cast<IfInst>(this)->getInstOperands(); } } @@ -287,15 +280,6 @@ void Instruction::dropAllReferences() { // Make sure to drop references held by instructions within the body. cast<ForInst>(this)->getBody()->dropAllReferences(); break; - case Kind::If: { - // Make sure to drop references held by instructions within the 'then' and - // 'else' blocks. - auto *ifInst = cast<IfInst>(this); - ifInst->getThen()->dropAllReferences(); - if (auto *elseBlock = ifInst->getElse()) - elseBlock->dropAllReferences(); - break; - } case Kind::OperationInst: { auto *opInst = cast<OperationInst>(this); if (isTerminator()) @@ -810,54 +794,6 @@ mlir::extractForInductionVars(ArrayRef<ForInst *> forInsts) { return results; } //===----------------------------------------------------------------------===// -// IfInst -//===----------------------------------------------------------------------===// - -IfInst::IfInst(Location location, unsigned numOperands, IntegerSet set) - : Instruction(Kind::If, location), thenClause(this), elseClause(nullptr), - set(set) { - operands.reserve(numOperands); - - // The then of an 'if' inst always has one block. - thenClause.push_back(new Block()); -} - -IfInst::~IfInst() { - if (elseClause) - delete elseClause; - - // An IfInst's IntegerSet 'set' should not be deleted since it is - // allocated through MLIRContext's bump pointer allocator. -} - -IfInst *IfInst::create(Location location, ArrayRef<Value *> operands, - IntegerSet set) { - unsigned numOperands = operands.size(); - assert(numOperands == set.getNumOperands() && - "operand cound does not match the integer set operand count"); - - IfInst *inst = new IfInst(location, numOperands, set); - - for (auto *op : operands) - inst->operands.emplace_back(InstOperand(inst, op)); - - return inst; -} - -const AffineCondition IfInst::getCondition() const { - return AffineCondition(*this, set); -} - -MLIRContext *IfInst::getContext() const { - // Check for degenerate case of if instruction with no operands. - // This is unlikely, but legal. - if (operands.empty()) - return getFunction()->getContext(); - - return getOperand(0)->getType().getContext(); -} - -//===----------------------------------------------------------------------===// // Instruction Cloning //===----------------------------------------------------------------------===// @@ -931,40 +867,23 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, for (auto *opValue : getOperands()) operands.push_back(mapper.lookupOrDefault(const_cast<Value *>(opValue))); - if (auto *forInst = dyn_cast<ForInst>(this)) { - auto lbMap = forInst->getLowerBoundMap(); - auto ubMap = forInst->getUpperBoundMap(); + // Otherwise, this must be a ForInst. + auto *forInst = cast<ForInst>(this); + auto lbMap = forInst->getLowerBoundMap(); + auto ubMap = forInst->getUpperBoundMap(); - auto *newFor = ForInst::create( - getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()), - lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()), - ubMap, forInst->getStep()); + auto *newFor = ForInst::create( + getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()), + lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()), ubMap, + forInst->getStep()); - // Remember the induction variable mapping. - mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); - - // Recursively clone the body of the for loop. - for (auto &subInst : *forInst->getBody()) - newFor->getBody()->push_back(subInst.clone(mapper, context)); - - return newFor; - } - - // Otherwise, we must have an If instruction. - auto *ifInst = cast<IfInst>(this); - auto *newIf = IfInst::create(getLoc(), operands, ifInst->getIntegerSet()); - - auto *resultThen = newIf->getThen(); - for (auto &childInst : *ifInst->getThen()) - resultThen->push_back(childInst.clone(mapper, context)); - - if (ifInst->hasElse()) { - auto *resultElse = newIf->createElse(); - for (auto &childInst : *ifInst->getElse()) - resultElse->push_back(childInst.clone(mapper, context)); - } + // Remember the induction variable mapping. + mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); - return newIf; + // Recursively clone the body of the for loop. + for (auto &subInst : *forInst->getBody()) + newFor->getBody()->push_back(subInst.clone(mapper, context)); + return newFor; } Instruction *Instruction::clone(MLIRContext *context) const { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 099b218892f..2ab151f8913 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -281,7 +281,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) { if (!block || &block->back() != op) return op->emitOpError("must be the last instruction in the parent block"); - // Terminators may not exist in ForInst and IfInst. + // TODO(riverriddle) Terminators may not exist with an operation region. if (block->getContainingInst()) return op->emitOpError("may only be at the top level of a function"); diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 6418b062dc1..7103eeb7389 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -66,8 +66,6 @@ MLIRContext *IROperandOwner::getContext() const { return cast<OperationInst>(this)->getContext(); case Kind::ForInst: return cast<ForInst>(this)->getContext(); - case Kind::IfInst: - return cast<IfInst>(this)->getContext(); } } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index c477ad1bbc5..e5d6aa46565 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -996,8 +996,7 @@ Attribute Parser::parseAttribute(Type type) { AffineMap map; IntegerSet set; if (parseAffineMapOrIntegerSetReference(map, set)) - return (emitError("expected affine map or integer set attribute value"), - nullptr); + return nullptr; if (map) return builder.getAffineMapAttr(map); assert(set); @@ -2209,8 +2208,6 @@ public: const char *affineStructName); ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map, bool isLower); - ParseResult parseIfInst(); - ParseResult parseElseClause(Block *elseClause); ParseResult parseInstructions(Block *block); private: @@ -2392,10 +2389,6 @@ ParseResult FunctionParser::parseBlockBody(Block *block) { if (parseForInst()) return ParseFailure; break; - case Token::kw_if: - if (parseIfInst()) - return ParseFailure; - break; } } @@ -2935,12 +2928,18 @@ public: return false; } - /// Parse a keyword followed by a type. - bool parseKeywordType(const char *keyword, Type &result) override { - if (parser.getTokenSpelling() != keyword) - return parser.emitError("expected '" + Twine(keyword) + "'"); - parser.consumeToken(); - return !(result = parser.parseType()); + /// Parse an optional keyword. + bool parseOptionalKeyword(const char *keyword) override { + // Check that the current token is a bare identifier or keyword. + if (parser.getToken().isNot(Token::bare_identifier) && + !parser.getToken().isKeyword()) + return true; + + if (parser.getTokenSpelling() == keyword) { + parser.consumeToken(); + return false; + } + return true; } /// Parse an arbitrary attribute of a given type and return it in result. This @@ -3078,6 +3077,15 @@ public: return result == nullptr; } + /// Parses a list of blocks. + bool parseBlockList() override { + SmallVector<Block *, 2> results; + if (parser.parseOperationBlockList(results)) + return true; + parsedBlockLists.emplace_back(results); + return false; + } + //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// @@ -3099,6 +3107,11 @@ public: /// Emit a diagnostic at the specified location and return true. bool emitError(llvm::SMLoc loc, const Twine &message) override { + // If we emit an error, then cleanup any parsed block lists. + for (auto &blockList : parsedBlockLists) + parser.cleanupInvalidBlocks(blockList); + parsedBlockLists.clear(); + parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message); emittedError = true; return true; @@ -3106,7 +3119,13 @@ public: bool didEmitError() const { return emittedError; } + /// Returns the block lists that were parsed. + MutableArrayRef<SmallVector<Block *, 2>> getParsedBlockLists() { + return parsedBlockLists; + } + private: + std::vector<SmallVector<Block *, 2>> parsedBlockLists; SMLoc nameLoc; StringRef opName; FunctionParser &parser; @@ -3145,8 +3164,25 @@ OperationInst *FunctionParser::parseCustomOperation() { if (opAsmParser.didEmitError()) return nullptr; + // Check that enough block lists were reserved for those that were parsed. + auto parsedBlockLists = opAsmParser.getParsedBlockLists(); + if (parsedBlockLists.size() > opState.numBlockLists) { + opAsmParser.emitError( + opLoc, + "parsed more block lists than those reserved in the operation state"); + return nullptr; + } + // Otherwise, we succeeded. Use the state it parsed as our op information. - return builder.createOperation(opState); + auto *opInst = builder.createOperation(opState); + + // Resolve any parsed block lists. + for (unsigned i = 0, e = parsedBlockLists.size(); i != e; ++i) { + auto &opBlockList = opInst->getBlockList(i).getBlocks(); + opBlockList.insert(opBlockList.end(), parsedBlockLists[i].begin(), + parsedBlockLists[i].end()); + } + return opInst; } /// For instruction. @@ -3438,69 +3474,6 @@ IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs); } -/// If instruction. -/// -/// ml-if-head ::= `if` ml-if-cond trailing-location? `{` inst* `}` -/// | ml-if-head `else` `if` ml-if-cond trailing-location? -/// `{` inst* `}` -/// ml-if-inst ::= ml-if-head -/// | ml-if-head `else` `{` inst* `}` -/// -ParseResult FunctionParser::parseIfInst() { - auto loc = getToken().getLoc(); - consumeToken(Token::kw_if); - - IntegerSet set = parseIntegerSetReference(); - if (!set) - return ParseFailure; - - SmallVector<Value *, 4> operands; - if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(), - "integer set")) - return ParseFailure; - - IfInst *ifInst = - builder.createIf(getEncodedSourceLocation(loc), operands, set); - - // Try to parse the optional trailing location. - if (parseOptionalTrailingLocation(ifInst)) - return ParseFailure; - - Block *thenClause = ifInst->getThen(); - - // When parsing of an if instruction body fails, the IR contains - // the if instruction with the portion of the body that has been - // successfully parsed. - if (parseToken(Token::l_brace, "expected '{' before instruction list") || - parseBlock(thenClause) || - parseToken(Token::r_brace, "expected '}' after instruction list")) - return ParseFailure; - - if (consumeIf(Token::kw_else)) { - auto *elseClause = ifInst->createElse(); - if (parseElseClause(elseClause)) - return ParseFailure; - } - - // Reset insertion point to the current block. - builder.setInsertionPointToEnd(ifInst->getBlock()); - - return ParseSuccess; -} - -ParseResult FunctionParser::parseElseClause(Block *elseClause) { - if (getToken().is(Token::kw_if)) { - builder.setInsertionPointToEnd(elseClause); - return parseIfInst(); - } - - if (parseToken(Token::l_brace, "expected '{' before instruction list") || - parseBlock(elseClause) || - parseToken(Token::r_brace, "expected '}' after instruction list")) - return ParseFailure; - return ParseSuccess; -} - //===----------------------------------------------------------------------===// // Top-level entity parsing. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index 40e98b25cb3..ec00f98b3f5 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -91,7 +91,6 @@ TOK_KEYWORD(attributes) TOK_KEYWORD(bf16) TOK_KEYWORD(ceildiv) TOK_KEYWORD(dense) -TOK_KEYWORD(else) TOK_KEYWORD(splat) TOK_KEYWORD(f16) TOK_KEYWORD(f32) @@ -100,7 +99,6 @@ TOK_KEYWORD(false) TOK_KEYWORD(floordiv) TOK_KEYWORD(for) TOK_KEYWORD(func) -TOK_KEYWORD(if) TOK_KEYWORD(index) TOK_KEYWORD(loc) TOK_KEYWORD(max) diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index c2e1636626d..afd18a49b79 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -188,16 +188,6 @@ void CSE::simplifyBlock(Block *bb) { simplifyBlock(cast<ForInst>(i).getBody()); break; } - case Instruction::Kind::If: { - auto &ifInst = cast<IfInst>(i); - if (auto *elseBlock = ifInst.getElse()) { - ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(elseBlock); - } - ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(ifInst.getThen()); - break; - } } } } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index cee0a08a63c..eebbbe9daa7 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -99,16 +100,16 @@ public: SmallVector<ForInst *, 4> forInsts; SmallVector<OperationInst *, 4> loadOpInsts; SmallVector<OperationInst *, 4> storeOpInsts; - bool hasIfInst = false; + bool hasNonForRegion = false; void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } - void visitIfInst(IfInst *ifInst) { hasIfInst = true; } - void visitOperationInst(OperationInst *opInst) { - if (opInst->isa<LoadOp>()) + if (opInst->getNumBlockLists() != 0) + hasNonForRegion = true; + else if (opInst->isa<LoadOp>()) loadOpInsts.push_back(opInst); - if (opInst->isa<StoreOp>()) + else if (opInst->isa<StoreOp>()) storeOpInsts.push_back(opInst); } }; @@ -410,8 +411,8 @@ bool MemRefDependenceGraph::init(Function *f) { // all loads and store accesses it contains. LoopNestStateCollector collector; collector.walkForInst(forInst); - // Return false if IfInsts are found (not currently supported). - if (collector.hasIfInst) + // Return false if a non 'for' region was found (not currently supported). + if (collector.hasNonForRegion) return false; Node node(id++, &inst); for (auto *opInst : collector.loadOpInsts) { @@ -434,19 +435,18 @@ bool MemRefDependenceGraph::init(Function *f) { auto *memref = opInst->cast<LoadOp>()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } - if (auto storeOp = opInst->dyn_cast<StoreOp>()) { + } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) { // Create graph node for top-level store op. Node node(id++, &inst); node.stores.push_back(opInst); auto *memref = opInst->cast<StoreOp>()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); + } else if (opInst->getNumBlockLists() != 0) { + // Return false if another region is found (not currently supported). + return false; } } - // Return false if IfInsts are found (not currently supported). - if (isa<IfInst>(&inst)) - return false; } // Walk memref access lists and add graph edges between dependent nodes. diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 39ef758833b..6d63e4afd2d 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -119,15 +119,6 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return true; } - bool walkIfInstPostOrder(IfInst *ifInst) { - bool hasInnerLoops = - walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end()); - if (ifInst->hasElse()) - hasInnerLoops |= - walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end()); - return hasInnerLoops; - } - bool walkOpInstPostOrder(OperationInst *opInst) { for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index ab37ff63bad..f770684f519 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -246,7 +247,7 @@ public: PassResult runOnFunction(Function *function) override; bool lowerForInst(ForInst *forInst); - bool lowerIfInst(IfInst *ifInst); + bool lowerAffineIf(AffineIfOp *ifOp); bool lowerAffineApply(AffineApplyOp *op); static char passID; @@ -409,7 +410,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { // enabling easy nesting of "if" instructions and if-then-else-if chains. // // +--------------------------------+ -// | <code before the IfInst> | +// | <code before the AffineIfOp> | // | %zero = constant 0 : index | // | %v = affine_apply #expr1(%ops) | // | %c = cmpi "sge" %v, %zero | @@ -453,10 +454,11 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { // v v // +--------------------------------+ // | continue: | -// | <code after the IfInst> | +// | <code after the AffineIfOp> | // +--------------------------------+ // -bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { +bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) { + auto *ifInst = ifOp->getInstruction(); auto loc = ifInst->getLoc(); // Start by splitting the block containing the 'if' into two parts. The part @@ -466,22 +468,38 @@ bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { auto *continueBlock = condBlock->splitBlock(ifInst); // Create a block for the 'then' code, inserting it between the cond and - // continue blocks. Move the instructions over from the IfInst and add a + // continue blocks. Move the instructions over from the AffineIfOp and add a // branch to the continuation point. Block *thenBlock = new Block(); thenBlock->insertBefore(continueBlock); - auto *oldThen = ifInst->getThen(); - thenBlock->getInstructions().splice(thenBlock->begin(), - oldThen->getInstructions(), - oldThen->begin(), oldThen->end()); + // If the 'then' block is not empty, then splice the instructions. + auto &oldThenBlocks = ifOp->getThenBlocks(); + if (!oldThenBlocks.empty()) { + // We currently only handle one 'then' block. + if (std::next(oldThenBlocks.begin()) != oldThenBlocks.end()) + return true; + + Block *oldThen = &oldThenBlocks.front(); + + thenBlock->getInstructions().splice(thenBlock->begin(), + oldThen->getInstructions(), + oldThen->begin(), oldThen->end()); + } + FuncBuilder builder(thenBlock); builder.create<BranchOp>(loc, continueBlock); // Handle the 'else' block the same way, but we skip it if we have no else // code. Block *elseBlock = continueBlock; - if (auto *oldElse = ifInst->getElse()) { + auto &oldElseBlocks = ifOp->getElseBlocks(); + if (!oldElseBlocks.empty()) { + // We currently only handle one 'else' block. + if (std::next(oldElseBlocks.begin()) != oldElseBlocks.end()) + return true; + + auto *oldElse = &oldElseBlocks.front(); elseBlock = new Block(); elseBlock->insertBefore(continueBlock); @@ -493,7 +511,7 @@ bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { } // Ok, now we just have to handle the condition logic. - auto integerSet = ifInst->getCondition().getIntegerSet(); + auto integerSet = ifOp->getIntegerSet(); // Implement short-circuit logic. For each affine expression in the 'if' // condition, convert it into an affine map and call `affine_apply` to obtain @@ -593,29 +611,30 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) { PassResult LowerAffinePass::runOnFunction(Function *function) { SmallVector<Instruction *, 8> instsToRewrite; - // Collect all the If and For instructions as well as AffineApplyOps. We do - // this as a prepass to avoid invalidating the walker with our rewrite. + // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. + // We do this as a prepass to avoid invalidating the walker with our rewrite. function->walkInsts([&](Instruction *inst) { - if (isa<IfInst>(inst) || isa<ForInst>(inst)) + if (isa<ForInst>(inst)) instsToRewrite.push_back(inst); auto op = dyn_cast<OperationInst>(inst); - if (op && op->isa<AffineApplyOp>()) + if (op && (op->isa<AffineApplyOp>() || op->isa<AffineIfOp>())) instsToRewrite.push_back(inst); }); // Rewrite all of the ifs and fors. We walked the instructions in preorder, // so we know that we will rewrite them in the same order. for (auto *inst : instsToRewrite) - if (auto *ifInst = dyn_cast<IfInst>(inst)) { - if (lowerIfInst(ifInst)) - return failure(); - } else if (auto *forInst = dyn_cast<ForInst>(inst)) { + if (auto *forInst = dyn_cast<ForInst>(inst)) { if (lowerForInst(forInst)) return failure(); } else { auto op = cast<OperationInst>(inst); - if (lowerAffineApply(op->cast<AffineApplyOp>())) + if (auto ifOp = op->dyn_cast<AffineIfOp>()) { + if (lowerAffineIf(ifOp)) + return failure(); + } else if (lowerAffineApply(op->cast<AffineApplyOp>())) { return failure(); + } } return success(); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 09d961f85cd..2744b1d624c 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -559,9 +560,6 @@ static bool instantiateMaterialization(Instruction *inst, if (isa<ForInst>(inst)) return inst->emitError("NYI path ForInst"); - if (isa<IfInst>(inst)) - return inst->emitError("NYI path IfInst"); - // Create a builder here for unroll-and-jam effects. FuncBuilder b(inst); auto *opInst = cast<OperationInst>(inst); @@ -570,6 +568,9 @@ static bool instantiateMaterialization(Instruction *inst, if (opInst->isa<AffineApplyOp>()) { return false; } + if (opInst->getNumBlockLists() != 0) + return inst->emitError("NYI path Op with region"); + if (auto write = opInst->dyn_cast<VectorTransferWriteOp>()) { auto *clone = instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index bd39e47786a..ba59123c700 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -28,7 +28,6 @@ #define DEBUG_TYPE "simplify-affine-structure" using namespace mlir; -using llvm::report_fatal_error; namespace { @@ -42,9 +41,6 @@ struct SimplifyAffineStructures : public FunctionPass { PassResult runOnFunction(Function *f) override; - void visitIfInst(IfInst *ifInst); - void visitOperationInst(OperationInst *opInst); - static char passID; }; @@ -66,28 +62,19 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { return set; } -void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) { - auto set = ifInst->getCondition().getIntegerSet(); - ifInst->setIntegerSet(simplifyIntegerSet(set)); -} - -void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) { - for (auto attr : opInst->getAttrs()) { - if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) { - MutableAffineMap mMap(mapAttr.getValue()); - mMap.simplify(); - auto map = mMap.getAffineMap(); - opInst->setAttr(attr.first, AffineMapAttr::get(map)); - } - } -} - PassResult SimplifyAffineStructures::runOnFunction(Function *f) { - f->walkInsts([&](Instruction *inst) { - if (auto *opInst = dyn_cast<OperationInst>(inst)) - visitOperationInst(opInst); - if (auto *ifInst = dyn_cast<IfInst>(inst)) - visitIfInst(ifInst); + f->walkOps([&](OperationInst *opInst) { + for (auto attr : opInst->getAttrs()) { + if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) { + MutableAffineMap mMap(mapAttr.getValue()); + mMap.simplify(); + auto map = mMap.getAffineMap(); + opInst->setAttr(attr.first, AffineMapAttr::get(map)); + } else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>()) { + auto simplified = simplifyIntegerSet(setAttr.getValue()); + opInst->setAttr(attr.first, IntegerSetAttr::get(simplified)); + } + } }); return success(); diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index bae112dd3b9..595991c0109 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -243,14 +243,6 @@ func @non_instruction() { // ----- -func @invalid_if_conditional1() { - for %i = 1 to 10 { - if () { // expected-error {{expected ':' or '['}} - } -} - -// ----- - func @invalid_if_conditional2() { for %i = 1 to 10 { if (i)[N] : (i >= ) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} @@ -664,7 +656,11 @@ func @invalid_if_operands2(%N : index) { func @invalid_if_operands3(%N : index) { for %i = 1 to 10 { if #set0(%i)[%i] { - // expected-error@-1 {{value '%i' cannot be used as a symbol}} + // expected-error@-1 {{operand cannot be used as a symbol}} + } + } + return +} // ----- // expected-error@+1 {{expected '"' in string literal}} diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir index e3e1bbbbfad..8a90d12bd03 100644 --- a/mlir/test/IR/locations.mlir +++ b/mlir/test/IR/locations.mlir @@ -16,9 +16,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { } - // CHECK: ) loc(fused<"myPass">["foo", "foo2"]) - if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) { - } + // CHECK: } loc(fused<"myPass">["foo", "foo2"]) + if #set0(%2) { + } loc(fused<"myPass">["foo", "foo2"]) // CHECK: return %0 : i32 loc(unknown) return %1 : i32 loc(unknown) diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 33109606538..626f24569c6 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -287,13 +287,15 @@ func @ifinst(%N: index) { // CHECK: %c1_i32 = constant 1 : i32 %y = "add"(%x, %i) : (i32, index) -> i32 // CHECK: %0 = "add"(%c1_i32, %i0) : (i32, index) -> i32 %z = "mul"(%y, %y) : (i32, i32) -> i32 // CHECK: %1 = "mul"(%0, %0) : (i32, i32) -> i32 - } else if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N] { // CHECK } else if (#set1(%i0)[%arg0]) { - // CHECK: %c1 = constant 1 : index - %u = constant 1 : index - // CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1] - %w = affine_apply (d0,d1)[s0] -> (d0+d1+s0) (%i, %i) [%u] - } else { // CHECK } else { - %v = constant 3 : i32 // %c3_i32 = constant 3 : i32 + } else { // CHECK } else { + if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N] { // CHECK if (#set1(%i0)[%arg0]) { + // CHECK: %c1 = constant 1 : index + %u = constant 1 : index + // CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1] + %w = affine_apply (d0,d1)[s0] -> (d0+d1+s0) (%i, %i) [%u] + } else { // CHECK } else { + %v = constant 3 : i32 // %c3_i32 = constant 3 : i32 + } } // CHECK } } // CHECK } return // CHECK return @@ -751,11 +753,11 @@ func @type_alias() -> !i32_type_alias { func @verbose_if(%N: index) { %c = constant 200 : index - // CHECK: "if"(%c200, %arg0, %c200) {cond: #set0} : (index, index, index) -> () { - "if"(%c, %N, %c) { cond: #set0 } : (index, index, index) -> () { + // CHECK: if #set0(%c200)[%arg0, %c200] { + "if"(%c, %N, %c) { condition: #set0 } : (index, index, index) -> () { // CHECK-NEXT: "add" %y = "add"(%c, %N) : (index, index) -> index - // CHECK-NEXT: } { + // CHECK-NEXT: } else { } { // The else block list. // CHECK-NEXT: "add" %z = "add"(%c, %c) : (index, index) -> index diff --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir index cb2e14a56d5..69dace45165 100644 --- a/mlir/test/IR/pretty-locations.mlir +++ b/mlir/test/IR/pretty-locations.mlir @@ -21,10 +21,10 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { } - // CHECK: ) <"myPass">["foo", "foo2"] - if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) { - } + // CHECK: } <"myPass">["foo", "foo2"] + if #set0(%2) { + } loc(fused<"myPass">["foo", "foo2"]) // CHECK: return %0 : i32 [unknown] return %1 : i32 loc(unknown) -}
\ No newline at end of file +} diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index d170ce590f7..162f193f662 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -483,7 +483,7 @@ func @should_not_fuse_if_inst_at_top_level() { %c0 = constant 4 : index if #set0(%c0) { } - // Top-level IfInst should prevent fusion. + // Top-level IfOp should prevent fusion. // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } @@ -512,7 +512,7 @@ func @should_not_fuse_if_inst_in_loop_nest() { %v0 = load %m[%i1] : memref<10xf32> } - // IfInst in ForInst should prevent fusion. + // IfOp in ForInst should prevent fusion. // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } diff --git a/mlir/test/Transforms/memref-dependence-check.mlir b/mlir/test/Transforms/memref-dependence-check.mlir index 6f6ad3fafc7..628044ed77a 100644 --- a/mlir/test/Transforms/memref-dependence-check.mlir +++ b/mlir/test/Transforms/memref-dependence-check.mlir @@ -10,7 +10,7 @@ func @store_may_execute_before_load() { %cf7 = constant 7.0 : f32 %c0 = constant 4 : index // There is a dependence from store 0 to load 1 at depth 1 because the - // ancestor IfInst of the store, dominates the ancestor ForSmt of the load, + // ancestor IfOp of the store, dominates the ancestor ForSmt of the load, // and thus the store "may" conditionally execute before the load. if #set0(%c0) { for %i0 = 0 to 10 { diff --git a/mlir/test/Transforms/strip-debug-info.mlir b/mlir/test/Transforms/strip-debug-info.mlir index 5509c7aba55..13f009deb70 100644 --- a/mlir/test/Transforms/strip-debug-info.mlir +++ b/mlir/test/Transforms/strip-debug-info.mlir @@ -13,10 +13,10 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { } - // CHECK: if #set0(%c4) loc(unknown) + // CHECK: } loc(unknown) %2 = constant 4 : index - if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) { - } + if #set0(%2) { + } loc(fused<"myPass">["foo", "foo2"]) // CHECK: return %0 : i32 loc(unknown) return %1 : i32 loc("bar") |

