diff options
| author | Aart Bik <ajcbik@google.com> | 2019-11-26 19:52:02 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-26 19:52:38 -0800 |
| commit | e2232fbcee8a4bf4e2a6ab181f8fabb57633dda6 (patch) | |
| tree | 826c3e2f33dab6d38433ac7fa6ec6f9d81c6d72c | |
| parent | f27ceb726188d0b16c979fddf644e33886139006 (diff) | |
| download | bcm5719-llvm-e2232fbcee8a4bf4e2a6ab181f8fabb57633dda6.tar.gz bcm5719-llvm-e2232fbcee8a4bf4e2a6ab181f8fabb57633dda6.zip | |
[VectorOps] Refine BroadcastOp in VectorOps dialect
Since second argument is always fully overwritten and
shape is define in "to" clause, it is not needed.
Also renamed "into" to "to" now that arg is dropped.
PiperOrigin-RevId: 282686475
| -rw-r--r-- | mlir/include/mlir/Dialect/VectorOps/VectorOps.td | 18 | ||||
| -rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorOps.cpp | 18 | ||||
| -rw-r--r-- | mlir/test/Dialect/VectorOps/invalid.mlir | 4 | ||||
| -rw-r--r-- | mlir/test/Dialect/VectorOps/ops.mlir | 12 |
4 files changed, 24 insertions, 28 deletions
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index 34c2fa97e53..c78334dd54a 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -165,27 +165,25 @@ def Vector_ContractionOp : def Vector_BroadcastOp : Vector_Op<"broadcast", [NoSideEffect, PredOpTrait<"source operand and result have same element type", - TCresVTEtIsSameAsOpBase<0, 0>>, - PredOpTrait<"dest operand and result have same type", - TCresIsSameAsOpBase<0, 1>>]>, - Arguments<(ins AnyType:$source, AnyVector:$dest)>, + TCresVTEtIsSameAsOpBase<0, 0>>]>, + Arguments<(ins AnyType:$source)>, Results<(outs AnyVector:$vector)> { let summary = "broadcast operation"; let description = [{ - Broadcasts the scalar or k-D vector value in the source to the n-D - destination vector of a proper shape such that the broadcast makes sense. + Broadcasts the scalar or k-D vector value in the source operand + to a n-D result vector such that the broadcast makes sense. Examples: ``` %0 = constant 0.0 : f32 - %1 = vector.broadcast %0, %x : f32 into vector<16xf32> - %2 = vector.broadcast %1, %y : vector<16xf32> into vector<4x16xf32> + %1 = vector.broadcast %0 : f32 to vector<16xf32> + %2 = vector.broadcast %1 : vector<16xf32> to vector<4x16xf32> ``` }]; let extraClassDeclaration = [{ Type getSourceType() { return source()->getType(); } - VectorType getDestVectorType() { - return dest()->getType().cast<VectorType>(); + VectorType getVectorType() { + return vector()->getType().cast<VectorType>(); } }]; } diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index d09fd0fc2f2..fe320b91439 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -373,14 +373,14 @@ static LogicalResult verify(ExtractElementOp op) { //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, BroadcastOp op) { - p << op.getOperationName() << " " << *op.source() << ", " << *op.dest(); + p << op.getOperationName() << " " << *op.source(); p << " : " << op.getSourceType(); - p << " into " << op.getDestVectorType(); + p << " to " << op.getVectorType(); } static LogicalResult verify(BroadcastOp op) { VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>(); - VectorType dstVectorType = op.getDestVectorType(); + VectorType dstVectorType = op.getVectorType(); // Scalar to vector broadcast is always valid. A vector // to vector broadcast needs some additional checking. if (srcVectorType) { @@ -397,16 +397,14 @@ static LogicalResult verify(BroadcastOp op) { static ParseResult parseBroadcastOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType source, dest; + OpAsmParser::OperandType source; Type sourceType; - VectorType destType; - return failure(parser.parseOperand(source) || parser.parseComma() || - parser.parseOperand(dest) || + VectorType vectorType; + return failure(parser.parseOperand(source) || parser.parseColonType(sourceType) || - parser.parseKeywordType("into", destType) || + parser.parseKeywordType("to", vectorType) || parser.resolveOperand(source, sourceType, result.operands) || - parser.resolveOperand(dest, destType, result.operands) || - parser.addTypeToList(destType, result.types)); + parser.addTypeToList(vectorType, result.types)); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index 92e956ef29a..d672b1bf140 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -2,9 +2,9 @@ // ----- -func @broadcast_rank_too_high(%arg0: vector<4x4xf32>, %arg1: vector<4xf32>) { +func @broadcast_rank_too_high(%arg0: vector<4x4xf32>) { // expected-error@+1 {{source rank higher than destination rank}} - %2 = vector.broadcast %arg0, %arg1 : vector<4x4xf32> into vector<4xf32> + %1 = vector.broadcast %arg0 : vector<4x4xf32> to vector<4xf32> } // ----- diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index 51dbc4f0435..d167559ac0c 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -23,12 +23,12 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>) { } // CHECK-LABEL: @vector_broadcast -func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>) { - // CHECK: vector.broadcast %{{.*}}, %{{.*}} : f32 into vector<16xf32> - %0 = vector.broadcast %a, %b : f32 into vector<16xf32> - // CHECK-NEXT: vector.broadcast %{{.*}}, %{{.*}} : vector<16xf32> into vector<8x16xf32> - %1 = vector.broadcast %b, %c : vector<16xf32> into vector<8x16xf32> - return +func @vector_broadcast(%a: f32, %b: vector<16xf32>) -> vector<8x16xf32> { + // CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32> + %0 = vector.broadcast %a : f32 to vector<16xf32> + // CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32> + %1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32> + return %1 : vector<8x16xf32> } // CHECK-LABEL: @extractelement |

