summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/LoopOps/LoopOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/LoopOps/LoopOps.cpp')
-rw-r--r--mlir/lib/Dialect/LoopOps/LoopOps.cpp197
1 files changed, 195 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
index 5452b3d4ab8..4824421d190 100644
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -185,13 +185,13 @@ static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) {
return failure();
// Parse the 'then' region.
- if (parser.parseRegion(*thenRegion, {}, {}))
+ if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
// If we find an 'else' keyword then parse the 'else' region.
if (!parser.parseOptionalKeyword("else")) {
- if (parser.parseRegion(*elseRegion, {}, {}))
+ if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
}
@@ -222,6 +222,199 @@ static void print(OpAsmPrinter &p, IfOp op) {
}
//===----------------------------------------------------------------------===//
+// ParallelOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ParallelOp op) {
+ // Check that there is at least one value in lowerBound, upperBound and step.
+ // It is sufficient to test only step, because it is ensured already that the
+ // number of elements in lowerBound, upperBound and step are the same.
+ Operation::operand_range stepValues = op.step();
+ if (stepValues.empty())
+ return op.emitOpError(
+ "needs at least one tuple element for lowerBound, upperBound and step");
+
+ // Check whether all constant step values are positive.
+ for (Value stepValue : stepValues)
+ if (auto cst = dyn_cast_or_null<ConstantIndexOp>(stepValue.getDefiningOp()))
+ if (cst.getValue() <= 0)
+ return op.emitOpError("constant step operand must be positive");
+
+ // Check that the body defines the same number of block arguments as the
+ // number of tuple elements in step.
+ Block *body = &op.body().front();
+ if (body->getNumArguments() != stepValues.size())
+ return op.emitOpError(
+ "expects the same number of induction variables as bound and step "
+ "values");
+ for (auto arg : body->getArguments())
+ if (!arg.getType().isIndex())
+ return op.emitOpError(
+ "expects arguments for the induction variable to be of index type");
+
+ // Check that the number of results is the same as the number of ReduceOps.
+ SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
+ if (op.results().size() != reductions.size())
+ return op.emitOpError(
+ "expects number of results to be the same as number of reductions");
+
+ // Check that the types of the results and reductions are the same.
+ for (auto resultAndReduce : llvm::zip(op.results(), reductions)) {
+ auto resultType = std::get<0>(resultAndReduce).getType();
+ auto reduceOp = std::get<1>(resultAndReduce);
+ auto reduceType = reduceOp.operand().getType();
+ if (resultType != reduceType)
+ return reduceOp.emitOpError()
+ << "expects type of reduce to be the same as result type: "
+ << resultType;
+ }
+ return success();
+}
+
+static ParseResult parseParallelOp(OpAsmParser &parser,
+ OperationState &result) {
+ auto &builder = parser.getBuilder();
+ // Parse an opening `(` followed by induction variables followed by `)`
+ SmallVector<OpAsmParser::OperandType, 4> ivs;
+ if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
+ OpAsmParser::Delimiter::Paren))
+ return failure();
+
+ // Parse loop bounds.
+ SmallVector<OpAsmParser::OperandType, 4> lower;
+ if (parser.parseEqual() ||
+ parser.parseOperandList(lower, ivs.size(),
+ OpAsmParser::Delimiter::Paren) ||
+ parser.resolveOperands(lower, builder.getIndexType(), result.operands))
+ return failure();
+
+ SmallVector<OpAsmParser::OperandType, 4> upper;
+ if (parser.parseKeyword("to") ||
+ parser.parseOperandList(upper, ivs.size(),
+ OpAsmParser::Delimiter::Paren) ||
+ parser.resolveOperands(upper, builder.getIndexType(), result.operands))
+ return failure();
+
+ // Parse step value.
+ SmallVector<OpAsmParser::OperandType, 4> steps;
+ if (parser.parseKeyword("step") ||
+ parser.parseOperandList(steps, ivs.size(),
+ OpAsmParser::Delimiter::Paren) ||
+ parser.resolveOperands(steps, builder.getIndexType(), result.operands))
+ return failure();
+
+ // Now parse the body.
+ Region *body = result.addRegion();
+ SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
+ if (parser.parseRegion(*body, ivs, types))
+ return failure();
+
+ // Parse attributes and optional results (in case there is a reduce).
+ if (parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseOptionalColonTypeList(result.types))
+ return failure();
+
+ // Add a terminator if none was parsed.
+ ForOp::ensureTerminator(*body, builder, result.location);
+
+ return success();
+}
+
+static void print(OpAsmPrinter &p, ParallelOp op) {
+ p << op.getOperationName() << " (";
+ p.printOperands(op.body().front().getArguments());
+ p << ") = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step ("
+ << op.step() << ")";
+ p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
+ p.printOptionalAttrDict(op.getAttrs());
+ if (!op.results().empty())
+ p << " : " << op.getResultTypes();
+}
+
+//===----------------------------------------------------------------------===//
+// ReduceOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ReduceOp op) {
+ // The region of a ReduceOp has two arguments of the same type as its operand.
+ auto type = op.operand().getType();
+ Block &block = op.reductionOperator().front();
+ if (block.empty())
+ return op.emitOpError("the block inside reduce should not be empty");
+ if (block.getNumArguments() != 2 ||
+ llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
+ return arg.getType() != type;
+ }))
+ return op.emitOpError() << "expects two arguments to reduce block of type "
+ << type;
+
+ // Check that the block is terminated by a ReduceReturnOp.
+ if (!isa<ReduceReturnOp>(block.getTerminator()))
+ return op.emitOpError("the block inside reduce should be terminated with a "
+ "'loop.reduce.return' op");
+
+ return success();
+}
+
+static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
+ // Parse an opening `(` followed by the reduced value followed by `)`
+ OpAsmParser::OperandType operand;
+ if (parser.parseLParen() || parser.parseOperand(operand) ||
+ parser.parseRParen())
+ return failure();
+
+ // Now parse the body.
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
+ return failure();
+
+ // And the type of the operand (and also what reduce computes on).
+ Type resultType;
+ if (parser.parseColonType(resultType) ||
+ parser.resolveOperand(operand, resultType, result.operands))
+ return failure();
+
+ return success();
+}
+
+static void print(OpAsmPrinter &p, ReduceOp op) {
+ p << op.getOperationName() << "(" << op.operand() << ") ";
+ p.printRegion(op.reductionOperator());
+ p << " : " << op.operand().getType();
+}
+
+//===----------------------------------------------------------------------===//
+// ReduceReturnOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ReduceReturnOp op) {
+ // The type of the return value should be the same type as the type of the
+ // operand of the enclosing ReduceOp.
+ auto reduceOp = cast<ReduceOp>(op.getParentOp());
+ Type reduceType = reduceOp.operand().getType();
+ if (reduceType != op.result().getType())
+ return op.emitOpError() << "needs to have type " << reduceType
+ << " (the type of the enclosing ReduceOp)";
+ return success();
+}
+
+static ParseResult parseReduceReturnOp(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::OperandType operand;
+ Type resultType;
+ if (parser.parseOperand(operand) || parser.parseColonType(resultType) ||
+ parser.resolveOperand(operand, resultType, result.operands))
+ return failure();
+
+ return success();
+}
+
+static void print(OpAsmPrinter &p, ReduceReturnOp op) {
+ p << op.getOperationName() << " " << op.result() << " : "
+ << op.result().getType();
+}
+
+//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
OpenPOWER on IntegriCloud