summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/VectorOps/VectorOps.cpp
diff options
context:
space:
mode:
authorAart Bik <ajcbik@google.com>2019-11-26 14:43:03 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-26 14:43:31 -0800
commitcf97263cb8cd4f6f21f00eabb9d6a007e221eaab (patch)
treeb313b50e227d67d8c88e71f01398ff03c662efb4 /mlir/lib/Dialect/VectorOps/VectorOps.cpp
parent18aec3e2e5b6ded304ff2a12f807bc2ace1b3a4b (diff)
downloadbcm5719-llvm-cf97263cb8cd4f6f21f00eabb9d6a007e221eaab.tar.gz
bcm5719-llvm-cf97263cb8cd4f6f21f00eabb9d6a007e221eaab.zip
[VectorOps] Add a BroadcastOp to the VectorOps dialect
PiperOrigin-RevId: 282643305
Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorOps.cpp')
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorOps.cpp41
1 files changed, 41 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index b73b771d80d..d09fd0fc2f2 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -369,6 +369,47 @@ static LogicalResult verify(ExtractElementOp op) {
}
//===----------------------------------------------------------------------===//
+// BroadcastOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, BroadcastOp op) {
+ p << op.getOperationName() << " " << *op.source() << ", " << *op.dest();
+ p << " : " << op.getSourceType();
+ p << " into " << op.getDestVectorType();
+}
+
+static LogicalResult verify(BroadcastOp op) {
+ VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
+ VectorType dstVectorType = op.getDestVectorType();
+ // Scalar to vector broadcast is always valid. A vector
+ // to vector broadcast needs some additional checking.
+ if (srcVectorType) {
+ const int64_t srcRank = srcVectorType.getRank();
+ const int64_t dstRank = dstVectorType.getRank();
+ // TODO(ajcbik): implement proper rank testing for broadcast;
+ // this is just a temporary placeholder check.
+ if (srcRank > dstRank) {
+ return op.emitOpError("source rank higher than destination rank");
+ }
+ }
+ return success();
+}
+
+static ParseResult parseBroadcastOp(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::OperandType source, dest;
+ Type sourceType;
+ VectorType destType;
+ return failure(parser.parseOperand(source) || parser.parseComma() ||
+ parser.parseOperand(dest) ||
+ parser.parseColonType(sourceType) ||
+ parser.parseKeywordType("into", destType) ||
+ parser.resolveOperand(source, sourceType, result.operands) ||
+ parser.resolveOperand(dest, destType, result.operands) ||
+ parser.addTypeToList(destType, result.types));
+}
+
+//===----------------------------------------------------------------------===//
// InsertElementOp
//===----------------------------------------------------------------------===//
OpenPOWER on IntegriCloud