diff options
| author | Aart Bik <ajcbik@google.com> | 2019-11-26 14:43:03 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-26 14:43:31 -0800 |
| commit | cf97263cb8cd4f6f21f00eabb9d6a007e221eaab (patch) | |
| tree | b313b50e227d67d8c88e71f01398ff03c662efb4 /mlir/lib/Dialect/VectorOps/VectorOps.cpp | |
| parent | 18aec3e2e5b6ded304ff2a12f807bc2ace1b3a4b (diff) | |
| download | bcm5719-llvm-cf97263cb8cd4f6f21f00eabb9d6a007e221eaab.tar.gz bcm5719-llvm-cf97263cb8cd4f6f21f00eabb9d6a007e221eaab.zip | |
[VectorOps] Add a BroadcastOp to the VectorOps dialect
PiperOrigin-RevId: 282643305
Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorOps.cpp')
| -rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorOps.cpp | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index b73b771d80d..d09fd0fc2f2 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -369,6 +369,47 @@ static LogicalResult verify(ExtractElementOp op) { } //===----------------------------------------------------------------------===// +// BroadcastOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, BroadcastOp op) { + p << op.getOperationName() << " " << *op.source() << ", " << *op.dest(); + p << " : " << op.getSourceType(); + p << " into " << op.getDestVectorType(); +} + +static LogicalResult verify(BroadcastOp op) { + VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>(); + VectorType dstVectorType = op.getDestVectorType(); + // Scalar to vector broadcast is always valid. A vector + // to vector broadcast needs some additional checking. + if (srcVectorType) { + const int64_t srcRank = srcVectorType.getRank(); + const int64_t dstRank = dstVectorType.getRank(); + // TODO(ajcbik): implement proper rank testing for broadcast; + // this is just a temporary placeholder check. + if (srcRank > dstRank) { + return op.emitOpError("source rank higher than destination rank"); + } + } + return success(); +} + +static ParseResult parseBroadcastOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType source, dest; + Type sourceType; + VectorType destType; + return failure(parser.parseOperand(source) || parser.parseComma() || + parser.parseOperand(dest) || + parser.parseColonType(sourceType) || + parser.parseKeywordType("into", destType) || + parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(dest, destType, result.operands) || + parser.addTypeToList(destType, result.types)); +} + +//===----------------------------------------------------------------------===// // InsertElementOp //===----------------------------------------------------------------------===// |

