diff options
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 53 |
1 files changed, 52 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index d4c27a69fb5..71bed9516ca 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -300,6 +300,31 @@ public: } }; +class VectorExtractElementOpConversion : public LLVMOpLowering { +public: + explicit VectorExtractElementOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); + auto extractEltOp = cast<vector::ExtractElementOp>(op); + auto vectorType = extractEltOp.getVectorType(); + auto llvmType = lowering.convertType(vectorType.getElementType()); + + // Bail if result type cannot be lowered. + if (!llvmType) + return matchFailure(); + + rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( + op, llvmType, adaptor.vector(), adaptor.position()); + return matchSuccess(); + } +}; + class VectorExtractOpConversion : public LLVMOpLowering { public: explicit VectorExtractOpConversion(MLIRContext *context, @@ -355,6 +380,31 @@ public: } }; +class VectorInsertElementOpConversion : public LLVMOpLowering { +public: + explicit VectorInsertElementOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::InsertElementOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto adaptor = vector::InsertElementOpOperandAdaptor(operands); + auto insertEltOp = cast<vector::InsertElementOp>(op); + auto vectorType = insertEltOp.getDestVectorType(); + auto llvmType = lowering.convertType(vectorType); + + // Bail if result type cannot be lowered. + if (!llvmType) + return matchFailure(); + + rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( + op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position()); + return matchSuccess(); + } +}; + class VectorInsertOpConversion : public LLVMOpLowering { public: explicit VectorInsertOpConversion(MLIRContext *context, @@ -566,7 +616,8 @@ public: void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion, - VectorExtractOpConversion, VectorInsertOpConversion, + VectorExtractElementOpConversion, VectorExtractOpConversion, + VectorInsertElementOpConversion, VectorInsertOpConversion, VectorOuterProductOpConversion, VectorTypeCastOpConversion>( converter.getDialect()->getContext(), converter); } |