diff options
-rw-r--r-- | mlir/include/mlir/Dialect/VectorOps/VectorOps.td | 59 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 53 | ||||
-rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorOps.cpp | 74 | ||||
-rw-r--r-- | mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 20 | ||||
-rw-r--r-- | mlir/test/Dialect/VectorOps/invalid.mlir | 26 | ||||
-rw-r--r-- | mlir/test/Dialect/VectorOps/ops.mlir | 18 |
6 files changed, 248 insertions, 2 deletions
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index 883e1bcfff7..eb05821952d 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -267,6 +267,33 @@ def Vector_ShuffleOp : }]; } +def Vector_ExtractElementOp : + Vector_Op<"extractelement", [NoSideEffect, + PredOpTrait<"operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>]>, + Arguments<(ins AnyVector:$vector, Index:$position)>, + Results<(outs AnyType)> { + let summary = "extractelement operation"; + let description = [{ + Takes an 1-D vector and a dynamic index position and extracts the + scalar at that position. Note that this instruction resembles + vector.extract, but is restricted to 1-D vectors and relaxed + to dynamic indices. It is meant to be closer to LLVM's version: + https://llvm.org/docs/LangRef.html#extractelement-instruction + + Example: + ``` + %c = constant 15 : i32 + %1 = vector.extractelement %0[%c : i32]: vector<16xf32> + ``` + }]; + let extraClassDeclaration = [{ + VectorType getVectorType() { + return vector()->getType().cast<VectorType>(); + } + }]; +} + def Vector_ExtractOp : Vector_Op<"extract", [NoSideEffect, PredOpTrait<"operand and result have same element type", @@ -346,6 +373,38 @@ def Vector_ExtractSlicesOp : }]; } +def Vector_InsertElementOp : + Vector_Op<"insertelement", [NoSideEffect, + PredOpTrait<"source operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>, + PredOpTrait<"dest operand and result have same type", + TCresIsSameAsOpBase<0, 1>>]>, + Arguments<(ins AnyType:$source, AnyVector:$dest, Index:$position)>, + Results<(outs AnyVector)> { + let summary = "insertelement operation"; + let description = [{ + Takes a scalar source, an 1-D destination vector and a dynamic index + position and inserts the source into the destination at the proper + position. Note that this instruction resembles vector.insert, but + is restricted to 1-D vectors and relaxed to dynamic indices. It is + meant to be closer to LLVM's version: + https://llvm.org/docs/LangRef.html#insertelement-instruction + + Example: + ``` + %c = constant 15 : i32 + %f = constant 0.0f : f32 + %1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32> + ``` + }]; + let extraClassDeclaration = [{ + Type getSourceType() { return source()->getType(); } + VectorType getDestVectorType() { + return dest()->getType().cast<VectorType>(); + } + }]; +} + def Vector_InsertOp : Vector_Op<"insert", [NoSideEffect, PredOpTrait<"source operand and result have same element type", diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index d4c27a69fb5..71bed9516ca 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -300,6 +300,31 @@ public: } }; +class VectorExtractElementOpConversion : public LLVMOpLowering { +public: + explicit VectorExtractElementOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); + auto extractEltOp = cast<vector::ExtractElementOp>(op); + auto vectorType = extractEltOp.getVectorType(); + auto llvmType = lowering.convertType(vectorType.getElementType()); + + // Bail if result type cannot be lowered. + if (!llvmType) + return matchFailure(); + + rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( + op, llvmType, adaptor.vector(), adaptor.position()); + return matchSuccess(); + } +}; + class VectorExtractOpConversion : public LLVMOpLowering { public: explicit VectorExtractOpConversion(MLIRContext *context, @@ -355,6 +380,31 @@ public: } }; +class VectorInsertElementOpConversion : public LLVMOpLowering { +public: + explicit VectorInsertElementOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::InsertElementOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto adaptor = vector::InsertElementOpOperandAdaptor(operands); + auto insertEltOp = cast<vector::InsertElementOp>(op); + auto vectorType = insertEltOp.getDestVectorType(); + auto llvmType = lowering.convertType(vectorType); + + // Bail if result type cannot be lowered. + if (!llvmType) + return matchFailure(); + + rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( + op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position()); + return matchSuccess(); + } +}; + class VectorInsertOpConversion : public LLVMOpLowering { public: explicit VectorInsertOpConversion(MLIRContext *context, @@ -566,7 +616,8 @@ public: void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion, - VectorExtractOpConversion, VectorInsertOpConversion, + VectorExtractElementOpConversion, VectorExtractOpConversion, + VectorInsertElementOpConversion, VectorInsertOpConversion, VectorOuterProductOpConversion, VectorTypeCastOpConversion>( converter.getDialect()->getContext(), converter); } diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 2dfa4568a3e..fc8abd710e9 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -347,6 +347,42 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() { } //===----------------------------------------------------------------------===// +// ExtractElementOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, vector::ExtractElementOp op) { + p << op.getOperationName() << " " << *op.vector() << "[" << *op.position() + << " : " << op.position()->getType() << "]"; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.vector()->getType(); +} + +static ParseResult parseExtractElementOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType vector, position; + Type positionType; + VectorType vectorType; + if (parser.parseOperand(vector) || parser.parseLSquare() || + parser.parseOperand(position) || parser.parseColonType(positionType) || + parser.parseRSquare() || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(vectorType)) + return failure(); + Type resultType = vectorType.getElementType(); + return failure( + parser.resolveOperand(vector, vectorType, result.operands) || + parser.resolveOperand(position, positionType, result.operands) || + parser.addTypeToList(resultType, result.types)); +} + +static LogicalResult verify(vector::ExtractElementOp op) { + VectorType vectorType = op.getVectorType(); + if (vectorType.getRank() != 1) + return op.emitOpError("expected 1-D vector"); + return success(); +} + +//===----------------------------------------------------------------------===// // ExtractOp //===----------------------------------------------------------------------===// @@ -685,6 +721,44 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) { } //===----------------------------------------------------------------------===// +// InsertElementOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, InsertElementOp op) { + p << op.getOperationName() << " " << *op.source() << ", " << *op.dest() << "[" + << *op.position() << " : " << op.position()->getType() << "]"; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.dest()->getType(); +} + +static ParseResult parseInsertElementOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType source, dest, position; + Type positionType; + VectorType destType; + if (parser.parseOperand(source) || parser.parseComma() || + parser.parseOperand(dest) || parser.parseLSquare() || + parser.parseOperand(position) || parser.parseColonType(positionType) || + parser.parseRSquare() || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(destType)) + return failure(); + Type sourceType = destType.getElementType(); + return failure( + parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(dest, destType, result.operands) || + parser.resolveOperand(position, positionType, result.operands) || + parser.addTypeToList(destType, result.types)); +} + +static LogicalResult verify(InsertElementOp op) { + auto dstVectorType = op.getDestVectorType(); + if (dstVectorType.getRank() != 1) + return op.emitOpError("expected 1-D vector"); + return success(); +} + +//===----------------------------------------------------------------------===// // InsertOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 0c4b23f2067..73aba05b3b3 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -280,6 +280,16 @@ func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> { // CHECK: %[[i3:.*]] = llvm.insertvalue %[[e3]], %[[i2]][2] : !llvm<"[3 x <4 x float>]"> // CHECK: llvm.return %[[i3]] : !llvm<"[3 x <4 x float>]"> +func @extract_element(%arg0: vector<16xf32>) -> f32 { + %0 = constant 15 : index + %1 = vector.extractelement %arg0[%0 : index]: vector<16xf32> + return %1 : f32 +} +// CHECK-LABEL: extract_element(%arg0: !llvm<"<16 x float>">) +// CHECK: %[[c:.*]] = llvm.mlir.constant(15 : index) : !llvm.i64 +// CHECK: %[[x:.*]] = llvm.extractelement %arg0[%[[c]] : !llvm.i64] : !llvm<"<16 x float>"> +// CHECK: llvm.return %[[x]] : !llvm.float + func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 { %0 = vector.extract %arg0[15 : i32]: vector<16xf32> return %0 : f32 @@ -315,6 +325,16 @@ func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 { // CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>"> // CHECK: llvm.return {{.*}} : !llvm.float +func @insert_element(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { + %0 = constant 3 : index + %1 = vector.insertelement %arg0, %arg1[%0 : index] : vector<4xf32> + return %1 : vector<4xf32> +} +// CHECK-LABEL: insert_element(%arg0: !llvm.float, %arg1: !llvm<"<4 x float>">) +// CHECK: %[[c:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64 +// CHECK: %[[x:.*]] = llvm.insertelement %arg0, %arg1[%[[c]] : !llvm.i64] : !llvm<"<4 x float>"> +// CHECK: llvm.return %[[x]] : !llvm<"<4 x float>"> + func @insert_element_into_vec_1d(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { %0 = vector.insert %arg0, %arg1[3 : i32] : f32 into vector<4xf32> return %0 : vector<4xf32> diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index 892c10cd20d..c04c8ea486a 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -60,12 +60,20 @@ func @shuffle_index_out_of_range(%arg0: vector<2xf32>, %arg1: vector<2xf32>) { // ----- func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) { - // expected-error@+1 {{custom op 'vector.shuffle' invalid mask length}} + // expected-error@+1 {{'vector.shuffle' invalid mask length}} %1 = vector.shuffle %arg0, %arg1 [] : vector<2xf32>, vector<2xf32> } // ----- +func @extract_element(%arg0: vector<4x4xf32>) { + %c = constant 3 : index + // expected-error@+1 {{'vector.extractelement' op expected 1-D vector}} + %1 = vector.extractelement %arg0[%c : index] : vector<4x4xf32> +} + +// ----- + func @extract_vector_type(%arg0: index) { // expected-error@+1 {{expected vector type}} %1 = vector.extract %arg0[] : index @@ -115,6 +123,22 @@ func @extract_position_overflow(%arg0: vector<4x8x16xf32>) { // ----- +func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) { + %c = constant 3 : index + // expected-error@+1 {{'vector.insertelement' op expected 1-D vector}} + %0 = vector.insertelement %arg0, %arg1[%c : index] : vector<4x4xf32> +} + +// ----- + +func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) { + %c = constant 3 : index + // expected-error@+1 {{'vector.insertelement' op failed to verify that source operand and result have same element type}} + %0 = "vector.insertelement" (%arg0, %arg1, %c) : (i32, vector<4xf32>, index) -> (vector<4xf32>) +} + +// ----- + func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) { // expected-error@+1 {{expected non-empty position attribute}} %1 = vector.insert %a, %b[] : f32 into vector<4x8x16xf32> diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index 1d28dea8282..69af80f46bc 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -53,6 +53,15 @@ func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> { return %1 : vector<3x4xf32> } +// CHECK-LABEL: @extract_element +func @extract_element(%a: vector<16xf32>) -> f32 { + // CHECK: %[[C15:.*]] = constant 15 : index + %c = constant 15 : index + // CHECK-NEXT: vector.extractelement %{{.*}}[%[[C15]] : index] : vector<16xf32> + %1 = vector.extractelement %a[%c : index] : vector<16xf32> + return %1 : f32 +} + // CHECK-LABEL: @extract func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) { // CHECK: vector.extract {{.*}}[3 : i32] : vector<4x8x16xf32> @@ -64,6 +73,15 @@ func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32 } +// CHECK-LABEL: @insert_element +func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> { + // CHECK: %[[C15:.*]] = constant 15 : index + %c = constant 15 : index + // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[%[[C15]] : index] : vector<16xf32> + %1 = vector.insertelement %a, %b[%c : index] : vector<16xf32> + return %1 : vector<16xf32> +} + // CHECK-LABEL: @insert func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { // CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32> |