diff options
author | Aart Bik <ajcbik@google.com> | 2019-12-06 12:38:52 -0800 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-06 12:39:25 -0800 |
commit | d37f27251f13ee6780267683fb9c4e69aa9c15a6 (patch) | |
tree | e45094c56cc3b9696f99a95750e22e7574355eb4 | |
parent | be3ed14658721aa458eeb887db5a5fc4b5a5fc1e (diff) | |
download | bcm5719-llvm-d37f27251f13ee6780267683fb9c4e69aa9c15a6.tar.gz bcm5719-llvm-d37f27251f13ee6780267683fb9c4e69aa9c15a6.zip |
[VecOps] Rename vector.[insert|extract]element to just vector.[insert|extract]
Since these operations lower to [insert|extract][element|value] at LLVM
dialect level, neither element nor value would correctly reflect the meaning.
PiperOrigin-RevId: 284240727
-rw-r--r-- | mlir/include/mlir/Dialect/VectorOps/VectorOps.td | 20 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 6 | ||||
-rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorOps.cpp | 44 | ||||
-rw-r--r-- | mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 4 | ||||
-rw-r--r-- | mlir/test/Dialect/VectorOps/invalid.mlir | 44 | ||||
-rw-r--r-- | mlir/test/Dialect/VectorOps/ops.mlir | 32 |
6 files changed, 72 insertions, 78 deletions
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index ebeecfbb715..6c2b4e6bb16 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -216,21 +216,21 @@ def Vector_BroadcastOp : }]; } -def Vector_ExtractElementOp : - Vector_Op<"extractelement", [NoSideEffect, +def Vector_ExtractOp : + Vector_Op<"extract", [NoSideEffect, PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>]>, Arguments<(ins AnyVector:$vector, I32ArrayAttr:$position)>, Results<(outs AnyType)> { - let summary = "extractelement operation"; + let summary = "extract operation"; let description = [{ Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at the proper position. Degenerates to an element type in the 0-D case. Examples: ``` - %1 = vector.extractelement %0[3]: vector<4x8x16xf32> - %2 = vector.extractelement %0[3, 3, 3]: vector<4x8x16xf32> + %1 = vector.extract %0[3]: vector<4x8x16xf32> + %2 = vector.extract %0[3, 3, 3]: vector<4x8x16xf32> ``` }]; let builders = [OpBuilder< @@ -243,15 +243,15 @@ def Vector_ExtractElementOp : }]; } -def Vector_InsertElementOp : - Vector_Op<"insertelement", [NoSideEffect, +def Vector_InsertOp : + Vector_Op<"insert", [NoSideEffect, PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, PredOpTrait<"dest operand and result have same type", TCresIsSameAsOpBase<0, 1>>]>, Arguments<(ins AnyType:$source, AnyVector:$dest, I32ArrayAttr:$position)>, Results<(outs AnyVector)> { - let summary = "insertelement operation"; + let summary = "insert operation"; let description = [{ Takes an n-D source vector, an (n+k)-D destination vector and a k-D position and inserts the n-D source into the (n+k)-D destination at the proper @@ -259,9 +259,9 @@ def Vector_InsertElementOp : Examples: ``` - %2 = vector.insertelement %0, %1[3 : i32]: + %2 = vector.insert %0, %1[3 : i32]: vector<8x16xf32> into vector<4x8x16xf32> - %5 = vector.insertelement %3, %4[3 : i32, 3 : i32, 3 : i32]: + %5 = vector.insert %3, %4[3 : i32, 3 : i32, 3 : i32]: f32 into vector<4x8x16xf32> ``` }]; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index c40c7c5242a..8adc415f820 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -238,15 +238,15 @@ class VectorExtractElementOpConversion : public LLVMOpLowering { public: explicit VectorExtractElementOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context, + : LLVMOpLowering(vector::ExtractOp::getOperationName(), context, typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); - auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); - auto extractOp = cast<vector::ExtractElementOp>(op); + auto adaptor = vector::ExtractOpOperandAdaptor(operands); + auto extractOp = cast<vector::ExtractOp>(op); auto vectorType = extractOp.vector()->getType().cast<VectorType>(); auto resultType = extractOp.getResult()->getType(); auto llvmResultType = lowering.convertType(resultType); diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 65441674165..c1e88aa0076 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -324,35 +324,33 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() { } //===----------------------------------------------------------------------===// -// ExtractElementOp +// ExtractOp //===----------------------------------------------------------------------===// -static Type inferExtractElementOpResultType(VectorType vectorType, - ArrayAttr position) { +static Type inferExtractOpResultType(VectorType vectorType, + ArrayAttr position) { if (static_cast<int64_t>(position.size()) == vectorType.getRank()) return vectorType.getElementType(); return VectorType::get(vectorType.getShape().drop_front(position.size()), vectorType.getElementType()); } -void vector::ExtractElementOp::build(Builder *builder, OperationState &result, - Value *source, - ArrayRef<int32_t> position) { +void vector::ExtractOp::build(Builder *builder, OperationState &result, + Value *source, ArrayRef<int32_t> position) { result.addOperands(source); auto positionAttr = builder->getI32ArrayAttr(position); - result.addTypes(inferExtractElementOpResultType( - source->getType().cast<VectorType>(), positionAttr)); + result.addTypes(inferExtractOpResultType(source->getType().cast<VectorType>(), + positionAttr)); result.addAttribute(getPositionAttrName(), positionAttr); } -static void print(OpAsmPrinter &p, vector::ExtractElementOp op) { +static void print(OpAsmPrinter &p, vector::ExtractOp op) { p << op.getOperationName() << " " << *op.vector() << op.position(); p.printOptionalAttrDict(op.getAttrs(), {"position"}); p << " : " << op.vector()->getType(); } -static ParseResult parseExtractElementOp(OpAsmParser &parser, - OperationState &result) { +static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) { llvm::SMLoc attributeLoc, typeLoc; SmallVector<NamedAttribute, 4> attrs; OpAsmParser::OperandType vector; @@ -375,13 +373,13 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser, attributeLoc, "expected position attribute of rank smaller than vector rank"); - Type resType = inferExtractElementOpResultType(vectorType, positionAttr); + Type resType = inferExtractOpResultType(vectorType, positionAttr); result.attributes = attrs; return failure(parser.resolveOperand(vector, type, result.operands) || parser.addTypeToList(resType, result.types)); } -static LogicalResult verify(vector::ExtractElementOp op) { +static LogicalResult verify(vector::ExtractOp op) { auto positionAttr = op.position().getValue(); if (positionAttr.empty()) return op.emitOpError("expected non-empty position attribute"); @@ -447,29 +445,26 @@ static ParseResult parseBroadcastOp(OpAsmParser &parser, } //===----------------------------------------------------------------------===// -// InsertElementOp +// InsertOp //===----------------------------------------------------------------------===// -void InsertElementOp::build(Builder *builder, OperationState &result, - Value *source, Value *dest, - ArrayRef<int32_t> position) { +void InsertOp::build(Builder *builder, OperationState &result, Value *source, + Value *dest, ArrayRef<int32_t> position) { result.addOperands({source, dest}); auto positionAttr = builder->getI32ArrayAttr(position); result.addTypes(dest->getType()); result.addAttribute(getPositionAttrName(), positionAttr); } -static void print(OpAsmPrinter &p, InsertElementOp op) { +static void print(OpAsmPrinter &p, InsertOp op) { p << op.getOperationName() << " " << *op.source() << ", " << *op.dest() << op.position(); - p.printOptionalAttrDict(op.getAttrs(), - {InsertElementOp::getPositionAttrName()}); + p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()}); p << " : " << op.getSourceType(); p << " into " << op.getDestVectorType(); } -static ParseResult parseInsertElementOp(OpAsmParser &parser, - OperationState &result) { +static ParseResult parseInsertOp(OpAsmParser &parser, OperationState &result) { SmallVector<NamedAttribute, 4> attrs; OpAsmParser::OperandType source, dest; Type sourceType; @@ -477,8 +472,7 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser, Attribute attr; return failure(parser.parseOperand(source) || parser.parseComma() || parser.parseOperand(dest) || - parser.parseAttribute(attr, - InsertElementOp::getPositionAttrName(), + parser.parseAttribute(attr, InsertOp::getPositionAttrName(), result.attributes) || parser.parseOptionalAttrDict(attrs) || parser.parseColonType(sourceType) || @@ -488,7 +482,7 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser, parser.addTypeToList(destType, result.types)); } -static LogicalResult verify(InsertElementOp op) { +static LogicalResult verify(InsertOp op) { auto positionAttr = op.position().getValue(); if (positionAttr.empty()) return op.emitOpError("expected non-empty position attribute"); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index b07a8634da4..8f66b44b094 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -231,7 +231,7 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector // CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]"> func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> { - %0 = vector.extractelement %arg0[0 : i32]: vector<4x3x16xf32> + %0 = vector.extract %arg0[0 : i32]: vector<4x3x16xf32> return %0 : vector<3x16xf32> } // CHECK-LABEL: extract_vec_2d_from_vec_3d @@ -239,7 +239,7 @@ func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> // CHECK: llvm.return %{{.*}} : !llvm<"[3 x <16 x float>]"> func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 { - %0 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32> + %0 = vector.extract %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32> return %0 : f32 } // CHECK-LABEL: extract_element_from_vec_3d diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index b70fc23ef1d..a0faa37ed03 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -31,79 +31,79 @@ func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) { // ----- -func @extract_element_vector_type(%arg0: index) { +func @extract_vector_type(%arg0: index) { // expected-error@+1 {{expected vector type}} - %1 = vector.extractelement %arg0[] : index + %1 = vector.extract %arg0[] : index } // ----- -func @extractelement_position_empty(%arg0: vector<4x8x16xf32>) { +func @extract_position_empty(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected non-empty position attribute}} - %1 = vector.extractelement %arg0[] : vector<4x8x16xf32> + %1 = vector.extract %arg0[] : vector<4x8x16xf32> } // ----- -func @extractelement_position_rank_overflow(%arg0: vector<4x8x16xf32>) { +func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute of rank smaller than vector}} - %1 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<4x8x16xf32> + %1 = vector.extract %arg0[0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<4x8x16xf32> } // ----- -func @extractelement_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) { +func @extract_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute of rank smaller than vector}} - %1 = "vector.extractelement" (%arg0) { position = [0 : i32, 0 : i32, 0 : i32, 0 : i32] } : (vector<4x8x16xf32>) -> (vector<16xf32>) + %1 = "vector.extract" (%arg0) { position = [0 : i32, 0 : i32, 0 : i32, 0 : i32] } : (vector<4x8x16xf32>) -> (vector<16xf32>) } // ----- -func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) { +func @extract_position_overflow(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute #2 to be a non-negative integer smaller than the corresponding vector dimension}} - %1 = vector.extractelement %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32> + %1 = vector.extract %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32> } // ----- -func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) { +func @extract_position_overflow(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}} - %1 = vector.extractelement %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32> + %1 = vector.extract %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32> } // ----- -func @insert_element_vector_type(%a: f32, %b: vector<4x8x16xf32>) { +func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) { // expected-error@+1 {{expected non-empty position attribute}} - %1 = vector.insertelement %a, %b[] : f32 into vector<4x8x16xf32> + %1 = vector.insert %a, %b[] : f32 into vector<4x8x16xf32> } // ----- -func @insert_element_vector_type(%a: f32, %b: vector<4x8x16xf32>) { +func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute of rank smaller than dest vector rank}} - %1 = vector.insertelement %a, %b[3 : i32,3 : i32,3 : i32,3 : i32,3 : i32,3 : i32] : f32 into vector<4x8x16xf32> + %1 = vector.insert %a, %b[3 : i32,3 : i32,3 : i32,3 : i32,3 : i32,3 : i32] : f32 into vector<4x8x16xf32> } // ----- -func @insert_element_vector_type(%a: vector<4xf32>, %b: vector<4x8x16xf32>) { +func @insert_vector_type(%a: vector<4xf32>, %b: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}} - %1 = vector.insertelement %a, %b[3 : i32] : vector<4xf32> into vector<4x8x16xf32> + %1 = vector.insert %a, %b[3 : i32] : vector<4xf32> into vector<4x8x16xf32> } // ----- -func @insert_element_vector_type(%a: f32, %b: vector<4x8x16xf32>) { +func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute rank to match the dest vector rank}} - %1 = vector.insertelement %a, %b[3 : i32,3 : i32] : f32 into vector<4x8x16xf32> + %1 = vector.insert %a, %b[3 : i32,3 : i32] : f32 into vector<4x8x16xf32> } // ----- -func @insertelement_position_overflow(%a: f32, %b: vector<4x8x16xf32>) { +func @insert_position_overflow(%a: f32, %b: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding dest vector dimension}} - %1 = vector.insertelement %a, %b[0 : i32, 0 : i32, -1 : i32] : f32 into vector<4x8x16xf32> + %1 = vector.insert %a, %b[0 : i32, 0 : i32, -1 : i32] : f32 into vector<4x8x16xf32> } // ----- diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index c1c911098ae..b98e749765c 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -35,25 +35,25 @@ func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: ve return %3 : vector<8x16xf32> } -// CHECK-LABEL: @extractelement -func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) { - // CHECK: vector.extractelement {{.*}}[3 : i32] : vector<4x8x16xf32> - %1 = vector.extractelement %arg0[3 : i32] : vector<4x8x16xf32> - // CHECK-NEXT: vector.extractelement {{.*}}[3 : i32, 3 : i32] : vector<4x8x16xf32> - %2 = vector.extractelement %arg0[3 : i32, 3 : i32] : vector<4x8x16xf32> - // CHECK-NEXT: vector.extractelement {{.*}}[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32> - %3 = vector.extractelement %arg0[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32> +// CHECK-LABEL: @extract +func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) { + // CHECK: vector.extract {{.*}}[3 : i32] : vector<4x8x16xf32> + %1 = vector.extract %arg0[3 : i32] : vector<4x8x16xf32> + // CHECK-NEXT: vector.extract {{.*}}[3 : i32, 3 : i32] : vector<4x8x16xf32> + %2 = vector.extract %arg0[3 : i32, 3 : i32] : vector<4x8x16xf32> + // CHECK-NEXT: vector.extract {{.*}}[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32> + %3 = vector.extract %arg0[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32> return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32 } -// CHECK-LABEL: @insertelement -func @insertelement(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) { - // CHECK: vector.insertelement %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32> - %1 = vector.insertelement %c, %res[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32> - // CHECK: vector.insertelement %{{.*}}, %{{.*}}[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32> - %2 = vector.insertelement %b, %res[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32> - // CHECK: vector.insertelement %{{.*}}, %{{.*}}[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32> - %3 = vector.insertelement %a, %res[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32> +// CHECK-LABEL: @insert +func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) { + // CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32> + %1 = vector.insert %c, %res[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32> + // CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32> + %2 = vector.insert %b, %res[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32> + // CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32> + %3 = vector.insert %a, %res[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32> return } |