diff options
Diffstat (limited to 'mlir/lib/Dialect/LoopOps/LoopOps.cpp')
-rw-r--r-- | mlir/lib/Dialect/LoopOps/LoopOps.cpp | 197 |
1 files changed, 195 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp index 5452b3d4ab8..4824421d190 100644 --- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp +++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp @@ -185,13 +185,13 @@ static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) { return failure(); // Parse the 'then' region. - if (parser.parseRegion(*thenRegion, {}, {})) + if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); // If we find an 'else' keyword then parse the 'else' region. if (!parser.parseOptionalKeyword("else")) { - if (parser.parseRegion(*elseRegion, {}, {})) + if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location); } @@ -222,6 +222,199 @@ static void print(OpAsmPrinter &p, IfOp op) { } //===----------------------------------------------------------------------===// +// ParallelOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(ParallelOp op) { + // Check that there is at least one value in lowerBound, upperBound and step. + // It is sufficient to test only step, because it is ensured already that the + // number of elements in lowerBound, upperBound and step are the same. + Operation::operand_range stepValues = op.step(); + if (stepValues.empty()) + return op.emitOpError( + "needs at least one tuple element for lowerBound, upperBound and step"); + + // Check whether all constant step values are positive. + for (Value stepValue : stepValues) + if (auto cst = dyn_cast_or_null<ConstantIndexOp>(stepValue.getDefiningOp())) + if (cst.getValue() <= 0) + return op.emitOpError("constant step operand must be positive"); + + // Check that the body defines the same number of block arguments as the + // number of tuple elements in step. + Block *body = &op.body().front(); + if (body->getNumArguments() != stepValues.size()) + return op.emitOpError( + "expects the same number of induction variables as bound and step " + "values"); + for (auto arg : body->getArguments()) + if (!arg.getType().isIndex()) + return op.emitOpError( + "expects arguments for the induction variable to be of index type"); + + // Check that the number of results is the same as the number of ReduceOps. + SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>()); + if (op.results().size() != reductions.size()) + return op.emitOpError( + "expects number of results to be the same as number of reductions"); + + // Check that the types of the results and reductions are the same. + for (auto resultAndReduce : llvm::zip(op.results(), reductions)) { + auto resultType = std::get<0>(resultAndReduce).getType(); + auto reduceOp = std::get<1>(resultAndReduce); + auto reduceType = reduceOp.operand().getType(); + if (resultType != reduceType) + return reduceOp.emitOpError() + << "expects type of reduce to be the same as result type: " + << resultType; + } + return success(); +} + +static ParseResult parseParallelOp(OpAsmParser &parser, + OperationState &result) { + auto &builder = parser.getBuilder(); + // Parse an opening `(` followed by induction variables followed by `)` + SmallVector<OpAsmParser::OperandType, 4> ivs; + if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, + OpAsmParser::Delimiter::Paren)) + return failure(); + + // Parse loop bounds. + SmallVector<OpAsmParser::OperandType, 4> lower; + if (parser.parseEqual() || + parser.parseOperandList(lower, ivs.size(), + OpAsmParser::Delimiter::Paren) || + parser.resolveOperands(lower, builder.getIndexType(), result.operands)) + return failure(); + + SmallVector<OpAsmParser::OperandType, 4> upper; + if (parser.parseKeyword("to") || + parser.parseOperandList(upper, ivs.size(), + OpAsmParser::Delimiter::Paren) || + parser.resolveOperands(upper, builder.getIndexType(), result.operands)) + return failure(); + + // Parse step value. + SmallVector<OpAsmParser::OperandType, 4> steps; + if (parser.parseKeyword("step") || + parser.parseOperandList(steps, ivs.size(), + OpAsmParser::Delimiter::Paren) || + parser.resolveOperands(steps, builder.getIndexType(), result.operands)) + return failure(); + + // Now parse the body. + Region *body = result.addRegion(); + SmallVector<Type, 4> types(ivs.size(), builder.getIndexType()); + if (parser.parseRegion(*body, ivs, types)) + return failure(); + + // Parse attributes and optional results (in case there is a reduce). + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseOptionalColonTypeList(result.types)) + return failure(); + + // Add a terminator if none was parsed. + ForOp::ensureTerminator(*body, builder, result.location); + + return success(); +} + +static void print(OpAsmPrinter &p, ParallelOp op) { + p << op.getOperationName() << " ("; + p.printOperands(op.body().front().getArguments()); + p << ") = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step (" + << op.step() << ")"; + p.printRegion(op.body(), /*printEntryBlockArgs=*/false); + p.printOptionalAttrDict(op.getAttrs()); + if (!op.results().empty()) + p << " : " << op.getResultTypes(); +} + +//===----------------------------------------------------------------------===// +// ReduceOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(ReduceOp op) { + // The region of a ReduceOp has two arguments of the same type as its operand. + auto type = op.operand().getType(); + Block &block = op.reductionOperator().front(); + if (block.empty()) + return op.emitOpError("the block inside reduce should not be empty"); + if (block.getNumArguments() != 2 || + llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) { + return arg.getType() != type; + })) + return op.emitOpError() << "expects two arguments to reduce block of type " + << type; + + // Check that the block is terminated by a ReduceReturnOp. + if (!isa<ReduceReturnOp>(block.getTerminator())) + return op.emitOpError("the block inside reduce should be terminated with a " + "'loop.reduce.return' op"); + + return success(); +} + +static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { + // Parse an opening `(` followed by the reduced value followed by `)` + OpAsmParser::OperandType operand; + if (parser.parseLParen() || parser.parseOperand(operand) || + parser.parseRParen()) + return failure(); + + // Now parse the body. + Region *body = result.addRegion(); + if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) + return failure(); + + // And the type of the operand (and also what reduce computes on). + Type resultType; + if (parser.parseColonType(resultType) || + parser.resolveOperand(operand, resultType, result.operands)) + return failure(); + + return success(); +} + +static void print(OpAsmPrinter &p, ReduceOp op) { + p << op.getOperationName() << "(" << op.operand() << ") "; + p.printRegion(op.reductionOperator()); + p << " : " << op.operand().getType(); +} + +//===----------------------------------------------------------------------===// +// ReduceReturnOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(ReduceReturnOp op) { + // The type of the return value should be the same type as the type of the + // operand of the enclosing ReduceOp. + auto reduceOp = cast<ReduceOp>(op.getParentOp()); + Type reduceType = reduceOp.operand().getType(); + if (reduceType != op.result().getType()) + return op.emitOpError() << "needs to have type " << reduceType + << " (the type of the enclosing ReduceOp)"; + return success(); +} + +static ParseResult parseReduceReturnOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType operand; + Type resultType; + if (parser.parseOperand(operand) || parser.parseColonType(resultType) || + parser.resolveOperand(operand, resultType, result.operands)) + return failure(); + + return success(); +} + +static void print(OpAsmPrinter &p, ReduceReturnOp op) { + p << op.getOperationName() << " " << op.result() << " : " + << op.result().getType(); +} + +//===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// |