diff options
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp')
| -rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 53 | 
1 files changed, 52 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index d4c27a69fb5..71bed9516ca 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -300,6 +300,31 @@ public:    }  }; +class VectorExtractElementOpConversion : public LLVMOpLowering { +public: +  explicit VectorExtractElementOpConversion(MLIRContext *context, +                                            LLVMTypeConverter &typeConverter) +      : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context, +                       typeConverter) {} + +  PatternMatchResult +  matchAndRewrite(Operation *op, ArrayRef<Value *> operands, +                  ConversionPatternRewriter &rewriter) const override { +    auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); +    auto extractEltOp = cast<vector::ExtractElementOp>(op); +    auto vectorType = extractEltOp.getVectorType(); +    auto llvmType = lowering.convertType(vectorType.getElementType()); + +    // Bail if result type cannot be lowered. +    if (!llvmType) +      return matchFailure(); + +    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( +        op, llvmType, adaptor.vector(), adaptor.position()); +    return matchSuccess(); +  } +}; +  class VectorExtractOpConversion : public LLVMOpLowering {  public:    explicit VectorExtractOpConversion(MLIRContext *context, @@ -355,6 +380,31 @@ public:    }  }; +class VectorInsertElementOpConversion : public LLVMOpLowering { +public: +  explicit VectorInsertElementOpConversion(MLIRContext *context, +                                           LLVMTypeConverter &typeConverter) +      : LLVMOpLowering(vector::InsertElementOp::getOperationName(), context, +                       typeConverter) {} + +  PatternMatchResult +  matchAndRewrite(Operation *op, ArrayRef<Value *> operands, +                  ConversionPatternRewriter &rewriter) const override { +    auto adaptor = vector::InsertElementOpOperandAdaptor(operands); +    auto insertEltOp = cast<vector::InsertElementOp>(op); +    auto vectorType = insertEltOp.getDestVectorType(); +    auto llvmType = lowering.convertType(vectorType); + +    // Bail if result type cannot be lowered. +    if (!llvmType) +      return matchFailure(); + +    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( +        op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position()); +    return matchSuccess(); +  } +}; +  class VectorInsertOpConversion : public LLVMOpLowering {  public:    explicit VectorInsertOpConversion(MLIRContext *context, @@ -566,7 +616,8 @@ public:  void mlir::populateVectorToLLVMConversionPatterns(      LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {    patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion, -                  VectorExtractOpConversion, VectorInsertOpConversion, +                  VectorExtractElementOpConversion, VectorExtractOpConversion, +                  VectorInsertElementOpConversion, VectorInsertOpConversion,                    VectorOuterProductOpConversion, VectorTypeCastOpConversion>(        converter.getDialect()->getContext(), converter);  }  | 

