From cf97263cb8cd4f6f21f00eabb9d6a007e221eaab Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 26 Nov 2019 14:43:03 -0800 Subject: [VectorOps] Add a BroadcastOp to the VectorOps dialect PiperOrigin-RevId: 282643305 --- mlir/lib/Dialect/VectorOps/VectorOps.cpp | 41 ++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) (limited to 'mlir/lib/Dialect/VectorOps/VectorOps.cpp') diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index b73b771d80d..d09fd0fc2f2 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -368,6 +368,47 @@ static LogicalResult verify(ExtractElementOp op) { return success(); } +//===----------------------------------------------------------------------===// +// BroadcastOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, BroadcastOp op) { + p << op.getOperationName() << " " << *op.source() << ", " << *op.dest(); + p << " : " << op.getSourceType(); + p << " into " << op.getDestVectorType(); +} + +static LogicalResult verify(BroadcastOp op) { + VectorType srcVectorType = op.getSourceType().dyn_cast(); + VectorType dstVectorType = op.getDestVectorType(); + // Scalar to vector broadcast is always valid. A vector + // to vector broadcast needs some additional checking. + if (srcVectorType) { + const int64_t srcRank = srcVectorType.getRank(); + const int64_t dstRank = dstVectorType.getRank(); + // TODO(ajcbik): implement proper rank testing for broadcast; + // this is just a temporary placeholder check. + if (srcRank > dstRank) { + return op.emitOpError("source rank higher than destination rank"); + } + } + return success(); +} + +static ParseResult parseBroadcastOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType source, dest; + Type sourceType; + VectorType destType; + return failure(parser.parseOperand(source) || parser.parseComma() || + parser.parseOperand(dest) || + parser.parseColonType(sourceType) || + parser.parseKeywordType("into", destType) || + parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(dest, destType, result.operands) || + parser.addTypeToList(destType, result.types)); +} + //===----------------------------------------------------------------------===// // InsertElementOp //===----------------------------------------------------------------------===// -- cgit v1.2.3