summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/AffineOps/AffineOps.h91
-rw-r--r--mlir/include/mlir/Analysis/NestedMatcher.h1
-rw-r--r--mlir/include/mlir/IR/Block.h11
-rw-r--r--mlir/include/mlir/IR/Builders.h4
-rw-r--r--mlir/include/mlir/IR/InstVisitor.h28
-rw-r--r--mlir/include/mlir/IR/Instruction.h1
-rw-r--r--mlir/include/mlir/IR/Instructions.h124
-rw-r--r--mlir/include/mlir/IR/OpImplementation.h21
-rw-r--r--mlir/include/mlir/IR/UseDefLists.h3
-rw-r--r--mlir/include/mlir/Transforms/MLPatternLoweringPass.h2
-rw-r--r--mlir/lib/AffineOps/AffineOps.cpp151
-rw-r--r--mlir/lib/AffineOps/DialectRegistration.cpp22
-rw-r--r--mlir/lib/Analysis/LoopAnalysis.cpp11
-rw-r--r--mlir/lib/Analysis/NestedMatcher.cpp20
-rw-r--r--mlir/lib/Analysis/Utils.cpp22
-rw-r--r--mlir/lib/Analysis/Verifier.cpp25
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp38
-rw-r--r--mlir/lib/IR/Builders.cpp7
-rw-r--r--mlir/lib/IR/Instruction.cpp109
-rw-r--r--mlir/lib/IR/Operation.cpp2
-rw-r--r--mlir/lib/IR/Value.cpp2
-rw-r--r--mlir/lib/Parser/Parser.cpp129
-rw-r--r--mlir/lib/Parser/TokenKinds.def2
-rw-r--r--mlir/lib/Transforms/CSE.cpp10
-rw-r--r--mlir/lib/Transforms/LoopFusion.cpp24
-rw-r--r--mlir/lib/Transforms/LoopUnroll.cpp9
-rw-r--r--mlir/lib/Transforms/LowerAffine.cpp59
-rw-r--r--mlir/lib/Transforms/MaterializeVectors.cpp7
-rw-r--r--mlir/lib/Transforms/SimplifyAffineStructures.cpp37
-rw-r--r--mlir/test/IR/invalid.mlir14
-rw-r--r--mlir/test/IR/locations.mlir6
-rw-r--r--mlir/test/IR/parser.mlir22
-rw-r--r--mlir/test/IR/pretty-locations.mlir8
-rw-r--r--mlir/test/Transforms/loop-fusion.mlir4
-rw-r--r--mlir/test/Transforms/memref-dependence-check.mlir2
-rw-r--r--mlir/test/Transforms/strip-debug-info.mlir6
36 files changed, 541 insertions, 493 deletions
diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h
deleted file mode 100644
index d511f628c3c..00000000000
--- a/mlir/include/mlir/AffineOps/AffineOps.h
+++ /dev/null
@@ -1,91 +0,0 @@
-//===- AffineOps.h - MLIR Affine Operations -------------------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file defines convenience types for working with Affine operations
-// in the MLIR instruction set.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_AFFINEOPS_AFFINEOPS_H
-#define MLIR_AFFINEOPS_AFFINEOPS_H
-
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/StandardTypes.h"
-
-namespace mlir {
-
-class AffineOpsDialect : public Dialect {
-public:
- AffineOpsDialect(MLIRContext *context);
-};
-
-/// The "if" operation represents an if–then–else construct for conditionally
-/// executing two regions of code. The operands to an if operation are an
-/// IntegerSet condition and a set of symbol/dimension operands to the
-/// condition set. The operation produces no results. For example:
-///
-/// if #set(%i) {
-/// ...
-/// } else {
-/// ...
-/// }
-///
-/// The 'else' blocks to the if operation are optional, and may be omitted. For
-/// example:
-///
-/// if #set(%i) {
-/// ...
-/// }
-///
-class AffineIfOp
- : public Op<AffineIfOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
-public:
- // Hooks to customize behavior of this op.
- static void build(Builder *builder, OperationState *result,
- IntegerSet condition, ArrayRef<Value *> conditionOperands);
-
- static StringRef getOperationName() { return "if"; }
- static StringRef getConditionAttrName() { return "condition"; }
-
- IntegerSet getIntegerSet() const;
- void setIntegerSet(IntegerSet newSet);
-
- /// Returns the list of 'then' blocks.
- BlockList &getThenBlocks();
- const BlockList &getThenBlocks() const {
- return const_cast<AffineIfOp *>(this)->getThenBlocks();
- }
-
- /// Returns the list of 'else' blocks.
- BlockList &getElseBlocks();
- const BlockList &getElseBlocks() const {
- return const_cast<AffineIfOp *>(this)->getElseBlocks();
- }
-
- bool verify() const;
- static bool parse(OpAsmParser *parser, OperationState *result);
- void print(OpAsmPrinter *p) const;
-
-private:
- friend class OperationInst;
- explicit AffineIfOp(const OperationInst *state) : Op(state) {}
-};
-
-} // end namespace mlir
-
-#endif
diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h
index 161bb217a10..c205d55488e 100644
--- a/mlir/include/mlir/Analysis/NestedMatcher.h
+++ b/mlir/include/mlir/Analysis/NestedMatcher.h
@@ -128,6 +128,7 @@ private:
void matchOne(Instruction *elem);
void visitForInst(ForInst *forInst) { matchOne(forInst); }
+ void visitIfInst(IfInst *ifInst) { matchOne(ifInst); }
void visitOperationInst(OperationInst *opInst) { matchOne(opInst); }
/// POD paylod.
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index e85ea772d0b..1b14d925d32 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -26,6 +26,7 @@
#include "llvm/ADT/PointerUnion.h"
namespace mlir {
+class IfInst;
class BlockList;
class BlockAndValueMapping;
@@ -61,7 +62,7 @@ public:
}
/// Returns the function that this block is part of, even if the block is
- /// nested under an OperationInst or ForInst.
+ /// nested under an IfInst or ForInst.
Function *getFunction();
const Function *getFunction() const {
return const_cast<Block *>(this)->getFunction();
@@ -324,7 +325,7 @@ private:
namespace mlir {
/// This class contains a list of basic blocks and has a notion of the object it
-/// is part of - a Function or OperationInst or ForInst.
+/// is part of - a Function or IfInst or ForInst.
class BlockList {
public:
explicit BlockList(Function *container);
@@ -364,14 +365,14 @@ public:
return &BlockList::blocks;
}
- /// A BlockList is part of a Function or and OperationInst/ForInst. If it is
- /// part of an OperationInst/ForInst, then return it, otherwise return null.
+ /// A BlockList is part of a Function or and IfInst/ForInst. If it is
+ /// part of an IfInst/ForInst, then return it, otherwise return null.
Instruction *getContainingInst();
const Instruction *getContainingInst() const {
return const_cast<BlockList *>(this)->getContainingInst();
}
- /// A BlockList is part of a Function or and OperationInst/ForInst. If it is
+ /// A BlockList is part of a Function or and IfInst/ForInst. If it is
/// part of a Function, then return it, otherwise return null.
Function *getContainingFunction();
const Function *getContainingFunction() const {
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 3271c12afde..156bd02bb52 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -286,6 +286,10 @@ public:
// Default step is 1.
ForInst *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1);
+ /// Creates if instruction.
+ IfInst *createIf(Location location, ArrayRef<Value *> operands,
+ IntegerSet set);
+
private:
Function *function;
Block *block = nullptr;
diff --git a/mlir/include/mlir/IR/InstVisitor.h b/mlir/include/mlir/IR/InstVisitor.h
index 78810da909d..b6a759e76f5 100644
--- a/mlir/include/mlir/IR/InstVisitor.h
+++ b/mlir/include/mlir/IR/InstVisitor.h
@@ -44,7 +44,7 @@
// lc.walk(function);
// numLoops = lc.numLoops;
//
-// There are 'visit' methods for OperationInst, ForInst, and
+// There are 'visit' methods for OperationInst, ForInst, IfInst, and
// Function, which recursively process all contained instructions.
//
// Note that if you don't implement visitXXX for some instruction type,
@@ -85,6 +85,8 @@ public:
switch (s->getKind()) {
case Instruction::Kind::For:
return static_cast<SubClass *>(this)->visitForInst(cast<ForInst>(s));
+ case Instruction::Kind::If:
+ return static_cast<SubClass *>(this)->visitIfInst(cast<IfInst>(s));
case Instruction::Kind::OperationInst:
return static_cast<SubClass *>(this)->visitOperationInst(
cast<OperationInst>(s));
@@ -102,6 +104,7 @@ public:
// When visiting a for inst, if inst, or an operation inst directly, these
// methods get called to indicate when transitioning into a new unit.
void visitForInst(ForInst *forInst) {}
+ void visitIfInst(IfInst *ifInst) {}
void visitOperationInst(OperationInst *opInst) {}
};
@@ -163,6 +166,23 @@ public:
static_cast<SubClass *>(this)->visitForInst(forInst);
}
+ void walkIfInst(IfInst *ifInst) {
+ static_cast<SubClass *>(this)->visitIfInst(ifInst);
+ static_cast<SubClass *>(this)->walk(ifInst->getThen()->begin(),
+ ifInst->getThen()->end());
+ if (auto *elseBlock = ifInst->getElse())
+ static_cast<SubClass *>(this)->walk(elseBlock->begin(), elseBlock->end());
+ }
+
+ void walkIfInstPostOrder(IfInst *ifInst) {
+ static_cast<SubClass *>(this)->walkPostOrder(ifInst->getThen()->begin(),
+ ifInst->getThen()->end());
+ if (auto *elseBlock = ifInst->getElse())
+ static_cast<SubClass *>(this)->walkPostOrder(elseBlock->begin(),
+ elseBlock->end());
+ static_cast<SubClass *>(this)->visitIfInst(ifInst);
+ }
+
// Function to walk a instruction.
RetTy walk(Instruction *s) {
static_assert(std::is_base_of<InstWalker, SubClass>::value,
@@ -173,6 +193,8 @@ public:
switch (s->getKind()) {
case Instruction::Kind::For:
return static_cast<SubClass *>(this)->walkForInst(cast<ForInst>(s));
+ case Instruction::Kind::If:
+ return static_cast<SubClass *>(this)->walkIfInst(cast<IfInst>(s));
case Instruction::Kind::OperationInst:
return static_cast<SubClass *>(this)->walkOpInst(cast<OperationInst>(s));
}
@@ -188,6 +210,9 @@ public:
case Instruction::Kind::For:
return static_cast<SubClass *>(this)->walkForInstPostOrder(
cast<ForInst>(s));
+ case Instruction::Kind::If:
+ return static_cast<SubClass *>(this)->walkIfInstPostOrder(
+ cast<IfInst>(s));
case Instruction::Kind::OperationInst:
return static_cast<SubClass *>(this)->walkOpInstPostOrder(
cast<OperationInst>(s));
@@ -206,6 +231,7 @@ public:
// processing their descendants in some way. When using RetTy, all of these
// need to be overridden.
void visitForInst(ForInst *forInst) {}
+ void visitIfInst(IfInst *ifInst) {}
void visitOperationInst(OperationInst *opInst) {}
void visitInstruction(Instruction *inst) {}
};
diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h
index 3dc1e76dd20..6a296b7348e 100644
--- a/mlir/include/mlir/IR/Instruction.h
+++ b/mlir/include/mlir/IR/Instruction.h
@@ -75,6 +75,7 @@ public:
enum class Kind {
OperationInst = (int)IROperandOwner::Kind::OperationInst,
For = (int)IROperandOwner::Kind::ForInst,
+ If = (int)IROperandOwner::Kind::IfInst,
};
Kind getKind() const { return (Kind)IROperandOwner::getKind(); }
diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h
index fb6b1b97ca0..71d832b8b90 100644
--- a/mlir/include/mlir/IR/Instructions.h
+++ b/mlir/include/mlir/IR/Instructions.h
@@ -794,6 +794,130 @@ private:
friend class ForInst;
};
+
+/// If instruction restricts execution to a subset of the loop iteration space.
+class IfInst : public Instruction {
+public:
+ static IfInst *create(Location location, ArrayRef<Value *> operands,
+ IntegerSet set);
+ ~IfInst();
+
+ //===--------------------------------------------------------------------===//
+ // Then, else, condition.
+ //===--------------------------------------------------------------------===//
+
+ Block *getThen() { return &thenClause.front(); }
+ const Block *getThen() const { return &thenClause.front(); }
+ Block *getElse() { return elseClause ? &elseClause->front() : nullptr; }
+ const Block *getElse() const {
+ return elseClause ? &elseClause->front() : nullptr;
+ }
+ bool hasElse() const { return elseClause != nullptr; }
+
+ Block *createElse() {
+ assert(elseClause == nullptr && "already has an else clause!");
+ elseClause = new BlockList(this);
+ elseClause->push_back(new Block());
+ return &elseClause->front();
+ }
+
+ const AffineCondition getCondition() const;
+
+ IntegerSet getIntegerSet() const { return set; }
+ void setIntegerSet(IntegerSet newSet) {
+ assert(newSet.getNumOperands() == operands.size());
+ set = newSet;
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Operands
+ //===--------------------------------------------------------------------===//
+
+ /// Operand iterators.
+ using operand_iterator = OperandIterator<IfInst, Value>;
+ using const_operand_iterator = OperandIterator<const IfInst, const Value>;
+
+ /// Operand iterator range.
+ using operand_range = llvm::iterator_range<operand_iterator>;
+ using const_operand_range = llvm::iterator_range<const_operand_iterator>;
+
+ unsigned getNumOperands() const { return operands.size(); }
+
+ Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); }
+ const Value *getOperand(unsigned idx) const {
+ return getInstOperand(idx).get();
+ }
+ void setOperand(unsigned idx, Value *value) {
+ getInstOperand(idx).set(value);
+ }
+
+ operand_iterator operand_begin() { return operand_iterator(this, 0); }
+ operand_iterator operand_end() {
+ return operand_iterator(this, getNumOperands());
+ }
+
+ const_operand_iterator operand_begin() const {
+ return const_operand_iterator(this, 0);
+ }
+ const_operand_iterator operand_end() const {
+ return const_operand_iterator(this, getNumOperands());
+ }
+
+ ArrayRef<InstOperand> getInstOperands() const { return operands; }
+ MutableArrayRef<InstOperand> getInstOperands() { return operands; }
+ InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; }
+ const InstOperand &getInstOperand(unsigned idx) const {
+ return getInstOperands()[idx];
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Other
+ //===--------------------------------------------------------------------===//
+
+ MLIRContext *getContext() const;
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
+ static bool classof(const IROperandOwner *ptr) {
+ return ptr->getKind() == IROperandOwner::Kind::IfInst;
+ }
+
+private:
+ // it is always present.
+ BlockList thenClause;
+ // 'else' clause of the if instruction. 'nullptr' if there is no else clause.
+ BlockList *elseClause;
+
+ // The integer set capturing the conditional guard.
+ IntegerSet set;
+
+ // Condition operands.
+ std::vector<InstOperand> operands;
+
+ explicit IfInst(Location location, unsigned numOperands, IntegerSet set);
+};
+
+/// AffineCondition represents a condition of the 'if' instruction.
+/// Its life span should not exceed that of the objects it refers to.
+/// AffineCondition does not provide its own methods for iterating over
+/// the operands since the iterators of the if instruction accomplish
+/// the same purpose.
+///
+/// AffineCondition is trivially copyable, so it should be passed by value.
+class AffineCondition {
+public:
+ const IfInst *getIfInst() const { return &inst; }
+ IntegerSet getIntegerSet() const { return set; }
+
+private:
+ // 'if' instruction that contains this affine condition.
+ const IfInst &inst;
+ // Integer set for this affine condition.
+ IntegerSet set;
+
+ AffineCondition(const IfInst &inst, IntegerSet set) : inst(inst), set(set) {}
+
+ friend class IfInst;
+};
} // end namespace mlir
#endif // MLIR_IR_INSTRUCTIONS_H
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index d3a5d35427f..1e319db3571 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -89,9 +89,6 @@ public:
/// Print the entire operation with the default generic assembly form.
virtual void printGenericOp(const OperationInst *op) = 0;
- /// Prints a block list.
- virtual void printBlockList(const BlockList &blocks) = 0;
-
private:
OpAsmPrinter(const OpAsmPrinter &) = delete;
void operator=(const OpAsmPrinter &) = delete;
@@ -198,19 +195,7 @@ public:
virtual bool parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse a keyword followed by a type.
- bool parseKeywordType(const char *keyword, Type &result) {
- return parseKeyword(keyword) || parseType(result);
- }
-
- /// Parse a keyword.
- bool parseKeyword(const char *keyword) {
- if (parseOptionalKeyword(keyword))
- return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'");
- return false;
- }
-
- /// If a keyword is present, then parse it.
- virtual bool parseOptionalKeyword(const char *keyword) = 0;
+ virtual bool parseKeywordType(const char *keyword, Type &result) = 0;
/// Add the specified type to the end of the specified type list and return
/// false. This is a helper designed to allow parse methods to be simple and
@@ -311,10 +296,6 @@ public:
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) = 0;
- /// Parses a block list. Any parsed blocks are filled in to the
- /// operation's block lists after the operation is created.
- virtual bool parseBlockList() = 0;
-
//===--------------------------------------------------------------------===//
// Methods for interacting with the parser
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h
index 80cd21362ce..053d3520103 100644
--- a/mlir/include/mlir/IR/UseDefLists.h
+++ b/mlir/include/mlir/IR/UseDefLists.h
@@ -81,9 +81,10 @@ public:
enum class Kind {
OperationInst,
ForInst,
+ IfInst,
/// These enums define ranges used for classof implementations.
- INST_LAST = ForInst,
+ INST_LAST = IfInst,
};
Kind getKind() const { return locationAndKind.getInt(); }
diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h
index 00c6577240c..978fa45ab23 100644
--- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h
+++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h
@@ -93,7 +93,7 @@ using OwningMLLoweringPatternList =
/// next _original_ operation is considered.
/// In other words, for each operation, the pass applies the first matching
/// rewriter in the list and advances to the (lexically) next operation.
-/// Non-operation instructions (ForInst) are ignored.
+/// Non-operation instructions (ForInst and IfInst) are ignored.
/// This is similar to greedy worklist-based pattern rewriter, except that this
/// operates on ML functions using an ML builder and does not maintain the work
/// list. Note that, as of the time of writing, worklist-based rewriter did not
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp
deleted file mode 100644
index 5b29467fc44..00000000000
--- a/mlir/lib/AffineOps/AffineOps.cpp
+++ /dev/null
@@ -1,151 +0,0 @@
-//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/AffineOps/AffineOps.h"
-#include "mlir/IR/Block.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/OpImplementation.h"
-using namespace mlir;
-
-//===----------------------------------------------------------------------===//
-// AffineOpsDialect
-//===----------------------------------------------------------------------===//
-
-AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
- : Dialect(/*namePrefix=*/"", context) {
- addOperations<AffineIfOp>();
-}
-
-//===----------------------------------------------------------------------===//
-// AffineIfOp
-//===----------------------------------------------------------------------===//
-
-void AffineIfOp::build(Builder *builder, OperationState *result,
- IntegerSet condition,
- ArrayRef<Value *> conditionOperands) {
- result->addAttribute(getConditionAttrName(), IntegerSetAttr::get(condition));
- result->addOperands(conditionOperands);
-
- // Reserve 2 block lists, one for the 'then' and one for the 'else' regions.
- result->reserveBlockLists(2);
-}
-
-bool AffineIfOp::verify() const {
- // Verify that we have a condition attribute.
- auto conditionAttr = getAttrOfType<IntegerSetAttr>(getConditionAttrName());
- if (!conditionAttr)
- return emitOpError("requires an integer set attribute named 'condition'");
-
- // Verify that the operands are valid dimension/symbols.
- IntegerSet condition = conditionAttr.getValue();
- for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
- const Value *operand = getOperand(i);
- if (i < condition.getNumDims() && !operand->isValidDim())
- return emitOpError("operand cannot be used as a dimension id");
- if (i >= condition.getNumDims() && !operand->isValidSymbol())
- return emitOpError("operand cannot be used as a symbol");
- }
-
- // Verify that the entry of each child blocklist does not have arguments.
- for (const auto &blockList : getInstruction()->getBlockLists()) {
- if (blockList.empty())
- continue;
-
- // TODO(riverriddle) We currently do not allow multiple blocks in child
- // block lists.
- if (std::next(blockList.begin()) != blockList.end())
- return emitOpError(
- "expects only one block per 'if' or 'else' block list");
- if (blockList.front().getTerminator())
- return emitOpError("expects region block to not have a terminator");
-
- for (const auto &b : blockList)
- if (b.getNumArguments() != 0)
- return emitOpError(
- "requires that child entry blocks have no arguments");
- }
- return false;
-}
-
-bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) {
- // Parse the condition attribute set.
- IntegerSetAttr conditionAttr;
- unsigned numDims;
- if (parser->parseAttribute(conditionAttr, getConditionAttrName().data(),
- result->attributes) ||
- parseDimAndSymbolList(parser, result->operands, numDims))
- return true;
-
- // Verify the condition operands.
- auto set = conditionAttr.getValue();
- if (set.getNumDims() != numDims)
- return parser->emitError(
- parser->getNameLoc(),
- "dim operand count and integer set dim count must match");
- if (numDims + set.getNumSymbols() != result->operands.size())
- return parser->emitError(
- parser->getNameLoc(),
- "symbol operand count and integer set symbol count must match");
-
- // Parse the 'then' block list.
- if (parser->parseBlockList())
- return true;
-
- // If we find an 'else' keyword then parse the else block list.
- if (!parser->parseOptionalKeyword("else")) {
- if (parser->parseBlockList())
- return true;
- }
-
- // Reserve 2 block lists, one for the 'then' and one for the 'else' regions.
- result->reserveBlockLists(2);
- return false;
-}
-
-void AffineIfOp::print(OpAsmPrinter *p) const {
- auto conditionAttr = getAttrOfType<IntegerSetAttr>(getConditionAttrName());
- *p << "if " << conditionAttr;
- printDimAndSymbolList(operand_begin(), operand_end(),
- conditionAttr.getValue().getNumDims(), p);
- p->printBlockList(getInstruction()->getBlockList(0));
-
- // Print the 'else' block list if it has any blocks.
- const auto &elseBlockList = getInstruction()->getBlockList(1);
- if (!elseBlockList.empty()) {
- *p << " else";
- p->printBlockList(elseBlockList);
- }
-}
-
-IntegerSet AffineIfOp::getIntegerSet() const {
- return getAttrOfType<IntegerSetAttr>(getConditionAttrName()).getValue();
-}
-void AffineIfOp::setIntegerSet(IntegerSet newSet) {
- setAttr(
- Identifier::get(getConditionAttrName(), getInstruction()->getContext()),
- IntegerSetAttr::get(newSet));
-}
-
-/// Returns the list of 'then' blocks.
-BlockList &AffineIfOp::getThenBlocks() {
- return getInstruction()->getBlockList(0);
-}
-
-/// Returns the list of 'else' blocks.
-BlockList &AffineIfOp::getElseBlocks() {
- return getInstruction()->getBlockList(1);
-}
diff --git a/mlir/lib/AffineOps/DialectRegistration.cpp b/mlir/lib/AffineOps/DialectRegistration.cpp
deleted file mode 100644
index 0afb32c1bd6..00000000000
--- a/mlir/lib/AffineOps/DialectRegistration.cpp
+++ /dev/null
@@ -1,22 +0,0 @@
-//===- DialectRegistration.cpp - Register Affine Op dialect ---------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/AffineOps/AffineOps.h"
-using namespace mlir;
-
-// Static initialization for Affine op dialect registration.
-static DialectRegistration<AffineOpsDialect> StandardOps;
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index 07c903a6613..219f356807a 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -21,7 +21,6 @@
#include "mlir/Analysis/LoopAnalysis.h"
-#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/NestedMatcher.h"
@@ -247,16 +246,6 @@ static bool isVectorizableLoopWithCond(const ForInst &loop,
return false;
}
- // No vectorization across unknown regions.
- auto regions = matcher::Op([](const Instruction &inst) -> bool {
- auto &opInst = cast<OperationInst>(inst);
- return opInst.getNumBlockLists() != 0 && !opInst.isa<AffineIfOp>();
- });
- auto regionsMatched = regions.match(forInst);
- if (!regionsMatched.empty()) {
- return false;
- }
-
auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite);
auto vectorTransfersMatched = vectorTransfers.match(forInst);
if (!vectorTransfersMatched.empty()) {
diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp
index 491a9bef1b9..4f32e9b22f4 100644
--- a/mlir/lib/Analysis/NestedMatcher.cpp
+++ b/mlir/lib/Analysis/NestedMatcher.cpp
@@ -16,7 +16,6 @@
// =============================================================================
#include "mlir/Analysis/NestedMatcher.h"
-#include "mlir/AffineOps/AffineOps.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/ArrayRef.h"
@@ -187,11 +186,6 @@ FilterFunctionType NestedPattern::getFilterFunction() {
return storage->filter;
}
-static bool isAffineIfOp(const Instruction &inst) {
- return isa<OperationInst>(inst) &&
- cast<OperationInst>(inst).isa<AffineIfOp>();
-}
-
namespace mlir {
namespace matcher {
@@ -200,22 +194,16 @@ NestedPattern Op(FilterFunctionType filter) {
}
NestedPattern If(NestedPattern child) {
- return NestedPattern(Instruction::Kind::OperationInst, child, isAffineIfOp);
+ return NestedPattern(Instruction::Kind::If, child, defaultFilterFunction);
}
NestedPattern If(FilterFunctionType filter, NestedPattern child) {
- return NestedPattern(Instruction::Kind::OperationInst, child,
- [filter](const Instruction &inst) {
- return isAffineIfOp(inst) && filter(inst);
- });
+ return NestedPattern(Instruction::Kind::If, child, filter);
}
NestedPattern If(ArrayRef<NestedPattern> nested) {
- return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineIfOp);
+ return NestedPattern(Instruction::Kind::If, nested, defaultFilterFunction);
}
NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
- return NestedPattern(Instruction::Kind::OperationInst, nested,
- [filter](const Instruction &inst) {
- return isAffineIfOp(inst) && filter(inst);
- });
+ return NestedPattern(Instruction::Kind::If, nested, filter);
}
NestedPattern For(NestedPattern child) {
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 0e77d4d9084..939a2ede618 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -22,7 +22,6 @@
#include "mlir/Analysis/Utils.h"
-#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/IR/Builders.h"
@@ -44,7 +43,7 @@ void mlir::getLoopIVs(const Instruction &inst,
// Traverse up the hierarchy collecing all 'for' instruction while skipping
// over 'if' instructions.
while (currInst && ((currForInst = dyn_cast<ForInst>(currInst)) ||
- cast<OperationInst>(currInst)->isa<AffineIfOp>())) {
+ isa<IfInst>(currInst))) {
if (currForInst)
loops->push_back(currForInst);
currInst = currInst->getParentInst();
@@ -360,12 +359,21 @@ static Instruction *getInstAtPosition(ArrayRef<unsigned> positions,
if (auto *childForInst = dyn_cast<ForInst>(&inst))
return getInstAtPosition(positions, level + 1, childForInst->getBody());
- for (auto &blockList : cast<OperationInst>(&inst)->getBlockLists()) {
- for (auto &b : blockList)
- if (auto *ret = getInstAtPosition(positions, level + 1, &b))
- return ret;
+ if (auto *ifInst = dyn_cast<IfInst>(&inst)) {
+ auto *ret = getInstAtPosition(positions, level + 1, ifInst->getThen());
+ if (ret != nullptr)
+ return ret;
+ if (auto *elseClause = ifInst->getElse())
+ return getInstAtPosition(positions, level + 1, elseClause);
+ }
+ if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
+ for (auto &blockList : opInst->getBlockLists()) {
+ for (auto &b : blockList)
+ if (auto *ret = getInstAtPosition(positions, level + 1, &b))
+ return ret;
+ }
+ return nullptr;
}
- return nullptr;
}
return nullptr;
}
diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp
index 474eeb2a28e..383a4878c35 100644
--- a/mlir/lib/Analysis/Verifier.cpp
+++ b/mlir/lib/Analysis/Verifier.cpp
@@ -73,6 +73,7 @@ public:
bool verifyBlock(const Block &block, bool isTopLevel);
bool verifyOperation(const OperationInst &op);
bool verifyForInst(const ForInst &forInst);
+ bool verifyIfInst(const IfInst &ifInst);
bool verifyDominance(const Block &block);
bool verifyInstDominance(const Instruction &inst);
@@ -179,6 +180,10 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) {
if (verifyForInst(cast<ForInst>(inst)))
return true;
break;
+ case Instruction::Kind::If:
+ if (verifyIfInst(cast<IfInst>(inst)))
+ return true;
+ break;
}
}
@@ -245,6 +250,18 @@ bool FuncVerifier::verifyForInst(const ForInst &forInst) {
return verifyBlock(*forInst.getBody(), /*isTopLevel=*/false);
}
+bool FuncVerifier::verifyIfInst(const IfInst &ifInst) {
+ // TODO: check that if conditions are properly formed.
+ if (verifyBlock(*ifInst.getThen(), /*isTopLevel*/ false))
+ return true;
+
+ if (auto *elseClause = ifInst.getElse())
+ if (verifyBlock(*elseClause, /*isTopLevel*/ false))
+ return true;
+
+ return false;
+}
+
bool FuncVerifier::verifyDominance(const Block &block) {
for (auto &inst : block) {
// Check that all operands on the instruction are ok.
@@ -266,6 +283,14 @@ bool FuncVerifier::verifyDominance(const Block &block) {
if (verifyDominance(*cast<ForInst>(inst).getBody()))
return true;
break;
+ case Instruction::Kind::If:
+ auto &ifInst = cast<IfInst>(inst);
+ if (verifyDominance(*ifInst.getThen()))
+ return true;
+ if (auto *elseClause = ifInst.getElse())
+ if (verifyDominance(*elseClause))
+ return true;
+ break;
}
}
return false;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index cb4c1f0edce..21bc3b824b1 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -145,6 +145,7 @@ private:
// Visit functions.
void visitInstruction(const Instruction *inst);
void visitForInst(const ForInst *forInst);
+ void visitIfInst(const IfInst *ifInst);
void visitOperationInst(const OperationInst *opInst);
void visitType(Type type);
void visitAttribute(Attribute attr);
@@ -196,6 +197,10 @@ void ModuleState::visitAttribute(Attribute attr) {
}
}
+void ModuleState::visitIfInst(const IfInst *ifInst) {
+ recordIntegerSetReference(ifInst->getIntegerSet());
+}
+
void ModuleState::visitForInst(const ForInst *forInst) {
AffineMap lbMap = forInst->getLowerBoundMap();
if (!hasCustomForm(lbMap))
@@ -220,6 +225,8 @@ void ModuleState::visitOperationInst(const OperationInst *op) {
void ModuleState::visitInstruction(const Instruction *inst) {
switch (inst->getKind()) {
+ case Instruction::Kind::If:
+ return visitIfInst(cast<IfInst>(inst));
case Instruction::Kind::For:
return visitForInst(cast<ForInst>(inst));
case Instruction::Kind::OperationInst:
@@ -1070,6 +1077,7 @@ public:
void print(const Instruction *inst);
void print(const OperationInst *inst);
void print(const ForInst *inst);
+ void print(const IfInst *inst);
void print(const Block *block, bool printBlockArgs = true);
void printOperation(const OperationInst *op);
@@ -1117,9 +1125,6 @@ public:
unsigned index) override;
/// Print a block list.
- void printBlockList(const BlockList &blocks) override {
- printBlockList(blocks, /*printEntryBlockArgs=*/true);
- }
void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) {
os << " {\n";
if (!blocks.empty()) {
@@ -1209,6 +1214,12 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) {
// Recursively number the stuff in the body.
numberValuesInBlock(*cast<ForInst>(&inst)->getBody());
break;
+ case Instruction::Kind::If: {
+ auto *ifInst = cast<IfInst>(&inst);
+ numberValuesInBlock(*ifInst->getThen());
+ if (auto *elseBlock = ifInst->getElse())
+ numberValuesInBlock(*elseBlock);
+ }
}
}
}
@@ -1349,7 +1360,8 @@ void FunctionPrinter::printFunctionSignature() {
}
void FunctionPrinter::print(const Block *block, bool printBlockArgs) {
- // Print the block label and argument list if requested.
+ // Print the block label and argument list, unless this is the first block of
+ // the function, or the first block of an IfInst/ForInst with no arguments.
if (printBlockArgs) {
os.indent(currentIndent);
printBlockName(block);
@@ -1406,6 +1418,8 @@ void FunctionPrinter::print(const Instruction *inst) {
return print(cast<OperationInst>(inst));
case Instruction::Kind::For:
return print(cast<ForInst>(inst));
+ case Instruction::Kind::If:
+ return print(cast<IfInst>(inst));
}
}
@@ -1433,6 +1447,22 @@ void FunctionPrinter::print(const ForInst *inst) {
os.indent(currentIndent) << "}";
}
+void FunctionPrinter::print(const IfInst *inst) {
+ os.indent(currentIndent) << "if ";
+ IntegerSet set = inst->getIntegerSet();
+ printIntegerSetReference(set);
+ printDimAndSymbolList(inst->getInstOperands(), set.getNumDims());
+ printTrailingLocation(inst->getLoc());
+ os << " {\n";
+ print(inst->getThen(), /*printBlockArgs=*/false);
+ os.indent(currentIndent) << "}";
+ if (inst->hasElse()) {
+ os << " else {\n";
+ print(inst->getElse(), /*printBlockArgs=*/false);
+ os.indent(currentIndent) << "}";
+ }
+}
+
void FunctionPrinter::printValueID(const Value *value,
bool printResultNo) const {
int resultNo = -1;
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index e174fdc1d00..4471ff25e94 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -327,3 +327,10 @@ ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
auto ubMap = AffineMap::getConstantMap(ub, context);
return createFor(location, {}, lbMap, {}, ubMap, step);
}
+
+IfInst *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
+ IntegerSet set) {
+ auto *inst = IfInst::create(location, operands, set);
+ block->getInstructions().insert(insertPoint, inst);
+ return inst;
+}
diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp
index 0ccab2305ec..6d74ed14257 100644
--- a/mlir/lib/IR/Instruction.cpp
+++ b/mlir/lib/IR/Instruction.cpp
@@ -73,6 +73,9 @@ void Instruction::destroy() {
case Kind::For:
delete cast<ForInst>(this);
break;
+ case Kind::If:
+ delete cast<IfInst>(this);
+ break;
}
}
@@ -138,6 +141,8 @@ unsigned Instruction::getNumOperands() const {
return cast<OperationInst>(this)->getNumOperands();
case Kind::For:
return cast<ForInst>(this)->getNumOperands();
+ case Kind::If:
+ return cast<IfInst>(this)->getNumOperands();
}
}
@@ -147,6 +152,8 @@ MutableArrayRef<InstOperand> Instruction::getInstOperands() {
return cast<OperationInst>(this)->getInstOperands();
case Kind::For:
return cast<ForInst>(this)->getInstOperands();
+ case Kind::If:
+ return cast<IfInst>(this)->getInstOperands();
}
}
@@ -280,6 +287,15 @@ void Instruction::dropAllReferences() {
// Make sure to drop references held by instructions within the body.
cast<ForInst>(this)->getBody()->dropAllReferences();
break;
+ case Kind::If: {
+ // Make sure to drop references held by instructions within the 'then' and
+ // 'else' blocks.
+ auto *ifInst = cast<IfInst>(this);
+ ifInst->getThen()->dropAllReferences();
+ if (auto *elseBlock = ifInst->getElse())
+ elseBlock->dropAllReferences();
+ break;
+ }
case Kind::OperationInst: {
auto *opInst = cast<OperationInst>(this);
if (isTerminator())
@@ -794,6 +810,54 @@ mlir::extractForInductionVars(ArrayRef<ForInst *> forInsts) {
return results;
}
//===----------------------------------------------------------------------===//
+// IfInst
+//===----------------------------------------------------------------------===//
+
+IfInst::IfInst(Location location, unsigned numOperands, IntegerSet set)
+ : Instruction(Kind::If, location), thenClause(this), elseClause(nullptr),
+ set(set) {
+ operands.reserve(numOperands);
+
+ // The then of an 'if' inst always has one block.
+ thenClause.push_back(new Block());
+}
+
+IfInst::~IfInst() {
+ if (elseClause)
+ delete elseClause;
+
+ // An IfInst's IntegerSet 'set' should not be deleted since it is
+ // allocated through MLIRContext's bump pointer allocator.
+}
+
+IfInst *IfInst::create(Location location, ArrayRef<Value *> operands,
+ IntegerSet set) {
+ unsigned numOperands = operands.size();
+ assert(numOperands == set.getNumOperands() &&
+ "operand cound does not match the integer set operand count");
+
+ IfInst *inst = new IfInst(location, numOperands, set);
+
+ for (auto *op : operands)
+ inst->operands.emplace_back(InstOperand(inst, op));
+
+ return inst;
+}
+
+const AffineCondition IfInst::getCondition() const {
+ return AffineCondition(*this, set);
+}
+
+MLIRContext *IfInst::getContext() const {
+ // Check for degenerate case of if instruction with no operands.
+ // This is unlikely, but legal.
+ if (operands.empty())
+ return getFunction()->getContext();
+
+ return getOperand(0)->getType().getContext();
+}
+
+//===----------------------------------------------------------------------===//
// Instruction Cloning
//===----------------------------------------------------------------------===//
@@ -867,23 +931,40 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper,
for (auto *opValue : getOperands())
operands.push_back(mapper.lookupOrDefault(const_cast<Value *>(opValue)));
- // Otherwise, this must be a ForInst.
- auto *forInst = cast<ForInst>(this);
- auto lbMap = forInst->getLowerBoundMap();
- auto ubMap = forInst->getUpperBoundMap();
+ if (auto *forInst = dyn_cast<ForInst>(this)) {
+ auto lbMap = forInst->getLowerBoundMap();
+ auto ubMap = forInst->getUpperBoundMap();
- auto *newFor = ForInst::create(
- getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
- lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()), ubMap,
- forInst->getStep());
+ auto *newFor = ForInst::create(
+ getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
+ lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()),
+ ubMap, forInst->getStep());
- // Remember the induction variable mapping.
- mapper.map(forInst->getInductionVar(), newFor->getInductionVar());
+ // Remember the induction variable mapping.
+ mapper.map(forInst->getInductionVar(), newFor->getInductionVar());
+
+ // Recursively clone the body of the for loop.
+ for (auto &subInst : *forInst->getBody())
+ newFor->getBody()->push_back(subInst.clone(mapper, context));
+
+ return newFor;
+ }
+
+ // Otherwise, we must have an If instruction.
+ auto *ifInst = cast<IfInst>(this);
+ auto *newIf = IfInst::create(getLoc(), operands, ifInst->getIntegerSet());
+
+ auto *resultThen = newIf->getThen();
+ for (auto &childInst : *ifInst->getThen())
+ resultThen->push_back(childInst.clone(mapper, context));
+
+ if (ifInst->hasElse()) {
+ auto *resultElse = newIf->createElse();
+ for (auto &childInst : *ifInst->getElse())
+ resultElse->push_back(childInst.clone(mapper, context));
+ }
- // Recursively clone the body of the for loop.
- for (auto &subInst : *forInst->getBody())
- newFor->getBody()->push_back(subInst.clone(mapper, context));
- return newFor;
+ return newIf;
}
Instruction *Instruction::clone(MLIRContext *context) const {
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 2ab151f8913..099b218892f 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -281,7 +281,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) {
if (!block || &block->back() != op)
return op->emitOpError("must be the last instruction in the parent block");
- // TODO(riverriddle) Terminators may not exist with an operation region.
+ // Terminators may not exist in ForInst and IfInst.
if (block->getContainingInst())
return op->emitOpError("may only be at the top level of a function");
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index 7103eeb7389..6418b062dc1 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -66,6 +66,8 @@ MLIRContext *IROperandOwner::getContext() const {
return cast<OperationInst>(this)->getContext();
case Kind::ForInst:
return cast<ForInst>(this)->getContext();
+ case Kind::IfInst:
+ return cast<IfInst>(this)->getContext();
}
}
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index e5d6aa46565..c477ad1bbc5 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -996,7 +996,8 @@ Attribute Parser::parseAttribute(Type type) {
AffineMap map;
IntegerSet set;
if (parseAffineMapOrIntegerSetReference(map, set))
- return nullptr;
+ return (emitError("expected affine map or integer set attribute value"),
+ nullptr);
if (map)
return builder.getAffineMapAttr(map);
assert(set);
@@ -2208,6 +2209,8 @@ public:
const char *affineStructName);
ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map,
bool isLower);
+ ParseResult parseIfInst();
+ ParseResult parseElseClause(Block *elseClause);
ParseResult parseInstructions(Block *block);
private:
@@ -2389,6 +2392,10 @@ ParseResult FunctionParser::parseBlockBody(Block *block) {
if (parseForInst())
return ParseFailure;
break;
+ case Token::kw_if:
+ if (parseIfInst())
+ return ParseFailure;
+ break;
}
}
@@ -2928,18 +2935,12 @@ public:
return false;
}
- /// Parse an optional keyword.
- bool parseOptionalKeyword(const char *keyword) override {
- // Check that the current token is a bare identifier or keyword.
- if (parser.getToken().isNot(Token::bare_identifier) &&
- !parser.getToken().isKeyword())
- return true;
-
- if (parser.getTokenSpelling() == keyword) {
- parser.consumeToken();
- return false;
- }
- return true;
+ /// Parse a keyword followed by a type.
+ bool parseKeywordType(const char *keyword, Type &result) override {
+ if (parser.getTokenSpelling() != keyword)
+ return parser.emitError("expected '" + Twine(keyword) + "'");
+ parser.consumeToken();
+ return !(result = parser.parseType());
}
/// Parse an arbitrary attribute of a given type and return it in result. This
@@ -3077,15 +3078,6 @@ public:
return result == nullptr;
}
- /// Parses a list of blocks.
- bool parseBlockList() override {
- SmallVector<Block *, 2> results;
- if (parser.parseOperationBlockList(results))
- return true;
- parsedBlockLists.emplace_back(results);
- return false;
- }
-
//===--------------------------------------------------------------------===//
// Methods for interacting with the parser
//===--------------------------------------------------------------------===//
@@ -3107,11 +3099,6 @@ public:
/// Emit a diagnostic at the specified location and return true.
bool emitError(llvm::SMLoc loc, const Twine &message) override {
- // If we emit an error, then cleanup any parsed block lists.
- for (auto &blockList : parsedBlockLists)
- parser.cleanupInvalidBlocks(blockList);
- parsedBlockLists.clear();
-
parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message);
emittedError = true;
return true;
@@ -3119,13 +3106,7 @@ public:
bool didEmitError() const { return emittedError; }
- /// Returns the block lists that were parsed.
- MutableArrayRef<SmallVector<Block *, 2>> getParsedBlockLists() {
- return parsedBlockLists;
- }
-
private:
- std::vector<SmallVector<Block *, 2>> parsedBlockLists;
SMLoc nameLoc;
StringRef opName;
FunctionParser &parser;
@@ -3164,25 +3145,8 @@ OperationInst *FunctionParser::parseCustomOperation() {
if (opAsmParser.didEmitError())
return nullptr;
- // Check that enough block lists were reserved for those that were parsed.
- auto parsedBlockLists = opAsmParser.getParsedBlockLists();
- if (parsedBlockLists.size() > opState.numBlockLists) {
- opAsmParser.emitError(
- opLoc,
- "parsed more block lists than those reserved in the operation state");
- return nullptr;
- }
-
// Otherwise, we succeeded. Use the state it parsed as our op information.
- auto *opInst = builder.createOperation(opState);
-
- // Resolve any parsed block lists.
- for (unsigned i = 0, e = parsedBlockLists.size(); i != e; ++i) {
- auto &opBlockList = opInst->getBlockList(i).getBlocks();
- opBlockList.insert(opBlockList.end(), parsedBlockLists[i].begin(),
- parsedBlockLists[i].end());
- }
- return opInst;
+ return builder.createOperation(opState);
}
/// For instruction.
@@ -3474,6 +3438,69 @@ IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims,
return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs);
}
+/// If instruction.
+///
+/// ml-if-head ::= `if` ml-if-cond trailing-location? `{` inst* `}`
+/// | ml-if-head `else` `if` ml-if-cond trailing-location?
+/// `{` inst* `}`
+/// ml-if-inst ::= ml-if-head
+/// | ml-if-head `else` `{` inst* `}`
+///
+ParseResult FunctionParser::parseIfInst() {
+ auto loc = getToken().getLoc();
+ consumeToken(Token::kw_if);
+
+ IntegerSet set = parseIntegerSetReference();
+ if (!set)
+ return ParseFailure;
+
+ SmallVector<Value *, 4> operands;
+ if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(),
+ "integer set"))
+ return ParseFailure;
+
+ IfInst *ifInst =
+ builder.createIf(getEncodedSourceLocation(loc), operands, set);
+
+ // Try to parse the optional trailing location.
+ if (parseOptionalTrailingLocation(ifInst))
+ return ParseFailure;
+
+ Block *thenClause = ifInst->getThen();
+
+ // When parsing of an if instruction body fails, the IR contains
+ // the if instruction with the portion of the body that has been
+ // successfully parsed.
+ if (parseToken(Token::l_brace, "expected '{' before instruction list") ||
+ parseBlock(thenClause) ||
+ parseToken(Token::r_brace, "expected '}' after instruction list"))
+ return ParseFailure;
+
+ if (consumeIf(Token::kw_else)) {
+ auto *elseClause = ifInst->createElse();
+ if (parseElseClause(elseClause))
+ return ParseFailure;
+ }
+
+ // Reset insertion point to the current block.
+ builder.setInsertionPointToEnd(ifInst->getBlock());
+
+ return ParseSuccess;
+}
+
+ParseResult FunctionParser::parseElseClause(Block *elseClause) {
+ if (getToken().is(Token::kw_if)) {
+ builder.setInsertionPointToEnd(elseClause);
+ return parseIfInst();
+ }
+
+ if (parseToken(Token::l_brace, "expected '{' before instruction list") ||
+ parseBlock(elseClause) ||
+ parseToken(Token::r_brace, "expected '}' after instruction list"))
+ return ParseFailure;
+ return ParseSuccess;
+}
+
//===----------------------------------------------------------------------===//
// Top-level entity parsing.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def
index ec00f98b3f5..40e98b25cb3 100644
--- a/mlir/lib/Parser/TokenKinds.def
+++ b/mlir/lib/Parser/TokenKinds.def
@@ -91,6 +91,7 @@ TOK_KEYWORD(attributes)
TOK_KEYWORD(bf16)
TOK_KEYWORD(ceildiv)
TOK_KEYWORD(dense)
+TOK_KEYWORD(else)
TOK_KEYWORD(splat)
TOK_KEYWORD(f16)
TOK_KEYWORD(f32)
@@ -99,6 +100,7 @@ TOK_KEYWORD(false)
TOK_KEYWORD(floordiv)
TOK_KEYWORD(for)
TOK_KEYWORD(func)
+TOK_KEYWORD(if)
TOK_KEYWORD(index)
TOK_KEYWORD(loc)
TOK_KEYWORD(max)
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index afd18a49b79..c2e1636626d 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -188,6 +188,16 @@ void CSE::simplifyBlock(Block *bb) {
simplifyBlock(cast<ForInst>(i).getBody());
break;
}
+ case Instruction::Kind::If: {
+ auto &ifInst = cast<IfInst>(i);
+ if (auto *elseBlock = ifInst.getElse()) {
+ ScopedMapTy::ScopeTy scope(knownValues);
+ simplifyBlock(elseBlock);
+ }
+ ScopedMapTy::ScopeTy scope(knownValues);
+ simplifyBlock(ifInst.getThen());
+ break;
+ }
}
}
}
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index eebbbe9daa7..cee0a08a63c 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -19,7 +19,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
@@ -100,16 +99,16 @@ public:
SmallVector<ForInst *, 4> forInsts;
SmallVector<OperationInst *, 4> loadOpInsts;
SmallVector<OperationInst *, 4> storeOpInsts;
- bool hasNonForRegion = false;
+ bool hasIfInst = false;
void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
+ void visitIfInst(IfInst *ifInst) { hasIfInst = true; }
+
void visitOperationInst(OperationInst *opInst) {
- if (opInst->getNumBlockLists() != 0)
- hasNonForRegion = true;
- else if (opInst->isa<LoadOp>())
+ if (opInst->isa<LoadOp>())
loadOpInsts.push_back(opInst);
- else if (opInst->isa<StoreOp>())
+ if (opInst->isa<StoreOp>())
storeOpInsts.push_back(opInst);
}
};
@@ -411,8 +410,8 @@ bool MemRefDependenceGraph::init(Function *f) {
// all loads and store accesses it contains.
LoopNestStateCollector collector;
collector.walkForInst(forInst);
- // Return false if a non 'for' region was found (not currently supported).
- if (collector.hasNonForRegion)
+ // Return false if IfInsts are found (not currently supported).
+ if (collector.hasIfInst)
return false;
Node node(id++, &inst);
for (auto *opInst : collector.loadOpInsts) {
@@ -435,18 +434,19 @@ bool MemRefDependenceGraph::init(Function *f) {
auto *memref = opInst->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
- } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
+ }
+ if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
// Create graph node for top-level store op.
Node node(id++, &inst);
node.stores.push_back(opInst);
auto *memref = opInst->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
- } else if (opInst->getNumBlockLists() != 0) {
- // Return false if another region is found (not currently supported).
- return false;
}
}
+ // Return false if IfInsts are found (not currently supported).
+ if (isa<IfInst>(&inst))
+ return false;
}
// Walk memref access lists and add graph edges between dependent nodes.
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index 6d63e4afd2d..39ef758833b 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -119,6 +119,15 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
return true;
}
+ bool walkIfInstPostOrder(IfInst *ifInst) {
+ bool hasInnerLoops =
+ walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end());
+ if (ifInst->hasElse())
+ hasInnerLoops |=
+ walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end());
+ return hasInnerLoops;
+ }
+
bool walkOpInstPostOrder(OperationInst *opInst) {
for (auto &blockList : opInst->getBlockLists())
for (auto &block : blockList)
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp
index f770684f519..ab37ff63bad 100644
--- a/mlir/lib/Transforms/LowerAffine.cpp
+++ b/mlir/lib/Transforms/LowerAffine.cpp
@@ -20,7 +20,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -247,7 +246,7 @@ public:
PassResult runOnFunction(Function *function) override;
bool lowerForInst(ForInst *forInst);
- bool lowerAffineIf(AffineIfOp *ifOp);
+ bool lowerIfInst(IfInst *ifInst);
bool lowerAffineApply(AffineApplyOp *op);
static char passID;
@@ -410,7 +409,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
// enabling easy nesting of "if" instructions and if-then-else-if chains.
//
// +--------------------------------+
-// | <code before the AffineIfOp> |
+// | <code before the IfInst> |
// | %zero = constant 0 : index |
// | %v = affine_apply #expr1(%ops) |
// | %c = cmpi "sge" %v, %zero |
@@ -454,11 +453,10 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
// v v
// +--------------------------------+
// | continue: |
-// | <code after the AffineIfOp> |
+// | <code after the IfInst> |
// +--------------------------------+
//
-bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) {
- auto *ifInst = ifOp->getInstruction();
+bool LowerAffinePass::lowerIfInst(IfInst *ifInst) {
auto loc = ifInst->getLoc();
// Start by splitting the block containing the 'if' into two parts. The part
@@ -468,38 +466,22 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) {
auto *continueBlock = condBlock->splitBlock(ifInst);
// Create a block for the 'then' code, inserting it between the cond and
- // continue blocks. Move the instructions over from the AffineIfOp and add a
+ // continue blocks. Move the instructions over from the IfInst and add a
// branch to the continuation point.
Block *thenBlock = new Block();
thenBlock->insertBefore(continueBlock);
- // If the 'then' block is not empty, then splice the instructions.
- auto &oldThenBlocks = ifOp->getThenBlocks();
- if (!oldThenBlocks.empty()) {
- // We currently only handle one 'then' block.
- if (std::next(oldThenBlocks.begin()) != oldThenBlocks.end())
- return true;
-
- Block *oldThen = &oldThenBlocks.front();
-
- thenBlock->getInstructions().splice(thenBlock->begin(),
- oldThen->getInstructions(),
- oldThen->begin(), oldThen->end());
- }
-
+ auto *oldThen = ifInst->getThen();
+ thenBlock->getInstructions().splice(thenBlock->begin(),
+ oldThen->getInstructions(),
+ oldThen->begin(), oldThen->end());
FuncBuilder builder(thenBlock);
builder.create<BranchOp>(loc, continueBlock);
// Handle the 'else' block the same way, but we skip it if we have no else
// code.
Block *elseBlock = continueBlock;
- auto &oldElseBlocks = ifOp->getElseBlocks();
- if (!oldElseBlocks.empty()) {
- // We currently only handle one 'else' block.
- if (std::next(oldElseBlocks.begin()) != oldElseBlocks.end())
- return true;
-
- auto *oldElse = &oldElseBlocks.front();
+ if (auto *oldElse = ifInst->getElse()) {
elseBlock = new Block();
elseBlock->insertBefore(continueBlock);
@@ -511,7 +493,7 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) {
}
// Ok, now we just have to handle the condition logic.
- auto integerSet = ifOp->getIntegerSet();
+ auto integerSet = ifInst->getCondition().getIntegerSet();
// Implement short-circuit logic. For each affine expression in the 'if'
// condition, convert it into an affine map and call `affine_apply` to obtain
@@ -611,30 +593,29 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) {
PassResult LowerAffinePass::runOnFunction(Function *function) {
SmallVector<Instruction *, 8> instsToRewrite;
- // Collect all the For instructions as well as AffineIfOps and AffineApplyOps.
- // We do this as a prepass to avoid invalidating the walker with our rewrite.
+ // Collect all the If and For instructions as well as AffineApplyOps. We do
+ // this as a prepass to avoid invalidating the walker with our rewrite.
function->walkInsts([&](Instruction *inst) {
- if (isa<ForInst>(inst))
+ if (isa<IfInst>(inst) || isa<ForInst>(inst))
instsToRewrite.push_back(inst);
auto op = dyn_cast<OperationInst>(inst);
- if (op && (op->isa<AffineApplyOp>() || op->isa<AffineIfOp>()))
+ if (op && op->isa<AffineApplyOp>())
instsToRewrite.push_back(inst);
});
// Rewrite all of the ifs and fors. We walked the instructions in preorder,
// so we know that we will rewrite them in the same order.
for (auto *inst : instsToRewrite)
- if (auto *forInst = dyn_cast<ForInst>(inst)) {
+ if (auto *ifInst = dyn_cast<IfInst>(inst)) {
+ if (lowerIfInst(ifInst))
+ return failure();
+ } else if (auto *forInst = dyn_cast<ForInst>(inst)) {
if (lowerForInst(forInst))
return failure();
} else {
auto op = cast<OperationInst>(inst);
- if (auto ifOp = op->dyn_cast<AffineIfOp>()) {
- if (lowerAffineIf(ifOp))
- return failure();
- } else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
+ if (lowerAffineApply(op->cast<AffineApplyOp>()))
return failure();
- }
}
return success();
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index 2744b1d624c..09d961f85cd 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -20,7 +20,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/LoopAnalysis.h"
@@ -560,6 +559,9 @@ static bool instantiateMaterialization(Instruction *inst,
if (isa<ForInst>(inst))
return inst->emitError("NYI path ForInst");
+ if (isa<IfInst>(inst))
+ return inst->emitError("NYI path IfInst");
+
// Create a builder here for unroll-and-jam effects.
FuncBuilder b(inst);
auto *opInst = cast<OperationInst>(inst);
@@ -568,9 +570,6 @@ static bool instantiateMaterialization(Instruction *inst,
if (opInst->isa<AffineApplyOp>()) {
return false;
}
- if (opInst->getNumBlockLists() != 0)
- return inst->emitError("NYI path Op with region");
-
if (auto write = opInst->dyn_cast<VectorTransferWriteOp>()) {
auto *clone = instantiate(&b, write, state->hwVectorType,
state->hwVectorInstance, state->substitutionsMap);
diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp
index ba59123c700..bd39e47786a 100644
--- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp
@@ -28,6 +28,7 @@
#define DEBUG_TYPE "simplify-affine-structure"
using namespace mlir;
+using llvm::report_fatal_error;
namespace {
@@ -41,6 +42,9 @@ struct SimplifyAffineStructures : public FunctionPass {
PassResult runOnFunction(Function *f) override;
+ void visitIfInst(IfInst *ifInst);
+ void visitOperationInst(OperationInst *opInst);
+
static char passID;
};
@@ -62,19 +66,28 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) {
return set;
}
-PassResult SimplifyAffineStructures::runOnFunction(Function *f) {
- f->walkOps([&](OperationInst *opInst) {
- for (auto attr : opInst->getAttrs()) {
- if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
- MutableAffineMap mMap(mapAttr.getValue());
- mMap.simplify();
- auto map = mMap.getAffineMap();
- opInst->setAttr(attr.first, AffineMapAttr::get(map));
- } else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>()) {
- auto simplified = simplifyIntegerSet(setAttr.getValue());
- opInst->setAttr(attr.first, IntegerSetAttr::get(simplified));
- }
+void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) {
+ auto set = ifInst->getCondition().getIntegerSet();
+ ifInst->setIntegerSet(simplifyIntegerSet(set));
+}
+
+void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) {
+ for (auto attr : opInst->getAttrs()) {
+ if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
+ MutableAffineMap mMap(mapAttr.getValue());
+ mMap.simplify();
+ auto map = mMap.getAffineMap();
+ opInst->setAttr(attr.first, AffineMapAttr::get(map));
}
+ }
+}
+
+PassResult SimplifyAffineStructures::runOnFunction(Function *f) {
+ f->walkInsts([&](Instruction *inst) {
+ if (auto *opInst = dyn_cast<OperationInst>(inst))
+ visitOperationInst(opInst);
+ if (auto *ifInst = dyn_cast<IfInst>(inst))
+ visitIfInst(ifInst);
});
return success();
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 595991c0109..bae112dd3b9 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -243,6 +243,14 @@ func @non_instruction() {
// -----
+func @invalid_if_conditional1() {
+ for %i = 1 to 10 {
+ if () { // expected-error {{expected ':' or '['}}
+ }
+}
+
+// -----
+
func @invalid_if_conditional2() {
for %i = 1 to 10 {
if (i)[N] : (i >= ) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}}
@@ -656,11 +664,7 @@ func @invalid_if_operands2(%N : index) {
func @invalid_if_operands3(%N : index) {
for %i = 1 to 10 {
if #set0(%i)[%i] {
- // expected-error@-1 {{operand cannot be used as a symbol}}
- }
- }
- return
-}
+ // expected-error@-1 {{value '%i' cannot be used as a symbol}}
// -----
// expected-error@+1 {{expected '"' in string literal}}
diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir
index 8a90d12bd03..e3e1bbbbfad 100644
--- a/mlir/test/IR/locations.mlir
+++ b/mlir/test/IR/locations.mlir
@@ -16,9 +16,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) {
for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) {
}
- // CHECK: } loc(fused<"myPass">["foo", "foo2"])
- if #set0(%2) {
- } loc(fused<"myPass">["foo", "foo2"])
+ // CHECK: ) loc(fused<"myPass">["foo", "foo2"])
+ if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) {
+ }
// CHECK: return %0 : i32 loc(unknown)
return %1 : i32 loc(unknown)
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 626f24569c6..33109606538 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -287,15 +287,13 @@ func @ifinst(%N: index) {
// CHECK: %c1_i32 = constant 1 : i32
%y = "add"(%x, %i) : (i32, index) -> i32 // CHECK: %0 = "add"(%c1_i32, %i0) : (i32, index) -> i32
%z = "mul"(%y, %y) : (i32, i32) -> i32 // CHECK: %1 = "mul"(%0, %0) : (i32, i32) -> i32
- } else { // CHECK } else {
- if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N] { // CHECK if (#set1(%i0)[%arg0]) {
- // CHECK: %c1 = constant 1 : index
- %u = constant 1 : index
- // CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1]
- %w = affine_apply (d0,d1)[s0] -> (d0+d1+s0) (%i, %i) [%u]
- } else { // CHECK } else {
- %v = constant 3 : i32 // %c3_i32 = constant 3 : i32
- }
+ } else if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N] { // CHECK } else if (#set1(%i0)[%arg0]) {
+ // CHECK: %c1 = constant 1 : index
+ %u = constant 1 : index
+ // CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1]
+ %w = affine_apply (d0,d1)[s0] -> (d0+d1+s0) (%i, %i) [%u]
+ } else { // CHECK } else {
+ %v = constant 3 : i32 // %c3_i32 = constant 3 : i32
} // CHECK }
} // CHECK }
return // CHECK return
@@ -753,11 +751,11 @@ func @type_alias() -> !i32_type_alias {
func @verbose_if(%N: index) {
%c = constant 200 : index
- // CHECK: if #set0(%c200)[%arg0, %c200] {
- "if"(%c, %N, %c) { condition: #set0 } : (index, index, index) -> () {
+ // CHECK: "if"(%c200, %arg0, %c200) {cond: #set0} : (index, index, index) -> () {
+ "if"(%c, %N, %c) { cond: #set0 } : (index, index, index) -> () {
// CHECK-NEXT: "add"
%y = "add"(%c, %N) : (index, index) -> index
- // CHECK-NEXT: } else {
+ // CHECK-NEXT: } {
} { // The else block list.
// CHECK-NEXT: "add"
%z = "add"(%c, %c) : (index, index) -> index
diff --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir
index 69dace45165..cb2e14a56d5 100644
--- a/mlir/test/IR/pretty-locations.mlir
+++ b/mlir/test/IR/pretty-locations.mlir
@@ -21,10 +21,10 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) {
for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) {
}
- // CHECK: } <"myPass">["foo", "foo2"]
- if #set0(%2) {
- } loc(fused<"myPass">["foo", "foo2"])
+ // CHECK: ) <"myPass">["foo", "foo2"]
+ if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) {
+ }
// CHECK: return %0 : i32 [unknown]
return %1 : i32 loc(unknown)
-}
+} \ No newline at end of file
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index 162f193f662..d170ce590f7 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -483,7 +483,7 @@ func @should_not_fuse_if_inst_at_top_level() {
%c0 = constant 4 : index
if #set0(%c0) {
}
- // Top-level IfOp should prevent fusion.
+ // Top-level IfInst should prevent fusion.
// CHECK: for %i0 = 0 to 10 {
// CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
// CHECK-NEXT: }
@@ -512,7 +512,7 @@ func @should_not_fuse_if_inst_in_loop_nest() {
%v0 = load %m[%i1] : memref<10xf32>
}
- // IfOp in ForInst should prevent fusion.
+ // IfInst in ForInst should prevent fusion.
// CHECK: for %i0 = 0 to 10 {
// CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
// CHECK-NEXT: }
diff --git a/mlir/test/Transforms/memref-dependence-check.mlir b/mlir/test/Transforms/memref-dependence-check.mlir
index 628044ed77a..6f6ad3fafc7 100644
--- a/mlir/test/Transforms/memref-dependence-check.mlir
+++ b/mlir/test/Transforms/memref-dependence-check.mlir
@@ -10,7 +10,7 @@ func @store_may_execute_before_load() {
%cf7 = constant 7.0 : f32
%c0 = constant 4 : index
// There is a dependence from store 0 to load 1 at depth 1 because the
- // ancestor IfOp of the store, dominates the ancestor ForSmt of the load,
+ // ancestor IfInst of the store, dominates the ancestor ForSmt of the load,
// and thus the store "may" conditionally execute before the load.
if #set0(%c0) {
for %i0 = 0 to 10 {
diff --git a/mlir/test/Transforms/strip-debug-info.mlir b/mlir/test/Transforms/strip-debug-info.mlir
index 13f009deb70..5509c7aba55 100644
--- a/mlir/test/Transforms/strip-debug-info.mlir
+++ b/mlir/test/Transforms/strip-debug-info.mlir
@@ -13,10 +13,10 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) {
for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) {
}
- // CHECK: } loc(unknown)
+ // CHECK: if #set0(%c4) loc(unknown)
%2 = constant 4 : index
- if #set0(%2) {
- } loc(fused<"myPass">["foo", "foo2"])
+ if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) {
+ }
// CHECK: return %0 : i32 loc(unknown)
return %1 : i32 loc("bar")
OpenPOWER on IntegriCloud