diff options
| -rw-r--r-- | mlir/include/mlir/Dialect/VectorOps/VectorOps.td | 18 | ||||
| -rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorOps.cpp | 18 | ||||
| -rw-r--r-- | mlir/test/Dialect/VectorOps/invalid.mlir | 4 | ||||
| -rw-r--r-- | mlir/test/Dialect/VectorOps/ops.mlir | 12 |
4 files changed, 24 insertions, 28 deletions
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index 34c2fa97e53..c78334dd54a 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -165,27 +165,25 @@ def Vector_ContractionOp : def Vector_BroadcastOp : Vector_Op<"broadcast", [NoSideEffect, PredOpTrait<"source operand and result have same element type", - TCresVTEtIsSameAsOpBase<0, 0>>, - PredOpTrait<"dest operand and result have same type", - TCresIsSameAsOpBase<0, 1>>]>, - Arguments<(ins AnyType:$source, AnyVector:$dest)>, + TCresVTEtIsSameAsOpBase<0, 0>>]>, + Arguments<(ins AnyType:$source)>, Results<(outs AnyVector:$vector)> { let summary = "broadcast operation"; let description = [{ - Broadcasts the scalar or k-D vector value in the source to the n-D - destination vector of a proper shape such that the broadcast makes sense. + Broadcasts the scalar or k-D vector value in the source operand + to a n-D result vector such that the broadcast makes sense. Examples: ``` %0 = constant 0.0 : f32 - %1 = vector.broadcast %0, %x : f32 into vector<16xf32> - %2 = vector.broadcast %1, %y : vector<16xf32> into vector<4x16xf32> + %1 = vector.broadcast %0 : f32 to vector<16xf32> + %2 = vector.broadcast %1 : vector<16xf32> to vector<4x16xf32> ``` }]; let extraClassDeclaration = [{ Type getSourceType() { return source()->getType(); } - VectorType getDestVectorType() { - return dest()->getType().cast<VectorType>(); + VectorType getVectorType() { + return vector()->getType().cast<VectorType>(); } }]; } diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index d09fd0fc2f2..fe320b91439 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -373,14 +373,14 @@ static LogicalResult verify(ExtractElementOp op) { //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, BroadcastOp op) { - p << op.getOperationName() << " " << *op.source() << ", " << *op.dest(); + p << op.getOperationName() << " " << *op.source(); p << " : " << op.getSourceType(); - p << " into " << op.getDestVectorType(); + p << " to " << op.getVectorType(); } static LogicalResult verify(BroadcastOp op) { VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>(); - VectorType dstVectorType = op.getDestVectorType(); + VectorType dstVectorType = op.getVectorType(); // Scalar to vector broadcast is always valid. A vector // to vector broadcast needs some additional checking. if (srcVectorType) { @@ -397,16 +397,14 @@ static LogicalResult verify(BroadcastOp op) { static ParseResult parseBroadcastOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType source, dest; + OpAsmParser::OperandType source; Type sourceType; - VectorType destType; - return failure(parser.parseOperand(source) || parser.parseComma() || - parser.parseOperand(dest) || + VectorType vectorType; + return failure(parser.parseOperand(source) || parser.parseColonType(sourceType) || - parser.parseKeywordType("into", destType) || + parser.parseKeywordType("to", vectorType) || parser.resolveOperand(source, sourceType, result.operands) || - parser.resolveOperand(dest, destType, result.operands) || - parser.addTypeToList(destType, result.types)); + parser.addTypeToList(vectorType, result.types)); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index 92e956ef29a..d672b1bf140 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -2,9 +2,9 @@ // ----- -func @broadcast_rank_too_high(%arg0: vector<4x4xf32>, %arg1: vector<4xf32>) { +func @broadcast_rank_too_high(%arg0: vector<4x4xf32>) { // expected-error@+1 {{source rank higher than destination rank}} - %2 = vector.broadcast %arg0, %arg1 : vector<4x4xf32> into vector<4xf32> + %1 = vector.broadcast %arg0 : vector<4x4xf32> to vector<4xf32> } // ----- diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index 51dbc4f0435..d167559ac0c 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -23,12 +23,12 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>) { } // CHECK-LABEL: @vector_broadcast -func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>) { - // CHECK: vector.broadcast %{{.*}}, %{{.*}} : f32 into vector<16xf32> - %0 = vector.broadcast %a, %b : f32 into vector<16xf32> - // CHECK-NEXT: vector.broadcast %{{.*}}, %{{.*}} : vector<16xf32> into vector<8x16xf32> - %1 = vector.broadcast %b, %c : vector<16xf32> into vector<8x16xf32> - return +func @vector_broadcast(%a: f32, %b: vector<16xf32>) -> vector<8x16xf32> { + // CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32> + %0 = vector.broadcast %a : f32 to vector<16xf32> + // CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32> + %1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32> + return %1 : vector<8x16xf32> } // CHECK-LABEL: @extractelement |

