diff options
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 124 |
1 files changed, 100 insertions, 24 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 8adc415f820..17fb93396d9 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -49,6 +49,18 @@ static LLVM::LLVMType getPtrToElementType(T containerType, .getPointerTo(); } +// Helper to reduce vector type by one rank at front. +static VectorType reducedVectorTypeFront(VectorType tp) { + assert((tp.getRank() > 1) && "unlowerable vector type"); + return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); +} + +// Helper to reduce vector type by *all* but one rank at back. +static VectorType reducedVectorTypeBack(VectorType tp) { + assert((tp.getRank() > 1) && "unlowerable vector type"); + return VectorType::get(tp.getShape().take_back(), tp.getElementType()); +} + class VectorBroadcastOpConversion : public LLVMOpLowering { public: explicit VectorBroadcastOpConversion(MLIRContext *context, @@ -135,8 +147,9 @@ private: return rewriter.create<LLVM::ShuffleVectorOp>( loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); } - Value *expand = expandRanks(value, loc, srcVectorType, - reducedVectorType(dstVectorType), rewriter); + Value *expand = + expandRanks(value, loc, srcVectorType, + reducedVectorTypeFront(dstVectorType), rewriter); Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType); for (int64_t d = 0; d < dim; ++d) { result = insertOne(result, expand, loc, llvmType, rank, d, rewriter); @@ -183,8 +196,8 @@ private: result = insertOne(result, one, loc, llvmType, rank, d, rewriter); } } else { - VectorType redSrcType = reducedVectorType(srcVectorType); - VectorType redDstType = reducedVectorType(dstVectorType); + VectorType redSrcType = reducedVectorTypeFront(srcVectorType); + VectorType redDstType = reducedVectorTypeFront(dstVectorType); Type redLlvmType = lowering.convertType(redSrcType); for (int64_t d = 0; d < dim; ++d) { int64_t pos = atStretch ? 0 : d; @@ -226,18 +239,12 @@ private: return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, value, rewriter.getI64ArrayAttr(pos)); } - - // Helper to reduce vector type by one rank. - static VectorType reducedVectorType(VectorType tp) { - assert((tp.getRank() > 1) && "unlowerable vector type"); - return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); - } }; -class VectorExtractElementOpConversion : public LLVMOpLowering { +class VectorExtractOpConversion : public LLVMOpLowering { public: - explicit VectorExtractElementOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) + explicit VectorExtractOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) : LLVMOpLowering(vector::ExtractOp::getOperationName(), context, typeConverter) {} @@ -247,11 +254,15 @@ public: auto loc = op->getLoc(); auto adaptor = vector::ExtractOpOperandAdaptor(operands); auto extractOp = cast<vector::ExtractOp>(op); - auto vectorType = extractOp.vector()->getType().cast<VectorType>(); + auto vectorType = extractOp.getVectorType(); auto resultType = extractOp.getResult()->getType(); auto llvmResultType = lowering.convertType(resultType); - auto positionArrayAttr = extractOp.position(); + + // Bail if result type cannot be lowered. + if (!llvmResultType) + return matchFailure(); + // One-shot extraction of vector from array (only requires extractvalue). if (resultType.isa<VectorType>()) { Value *extracted = rewriter.create<LLVM::ExtractValueOp>( @@ -260,15 +271,12 @@ public: return matchSuccess(); } - // Potential extraction of 1-D vector from struct. + // Potential extraction of 1-D vector from array. auto *context = op->getContext(); Value *extracted = adaptor.vector(); auto positionAttrs = positionArrayAttr.getValue(); - auto i32Type = rewriter.getIntegerType(32); if (positionAttrs.size() > 1) { - auto nDVectorType = vectorType; - auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(), - nDVectorType.getElementType()); + auto oneDVectorType = reducedVectorTypeBack(vectorType); auto nMinusOnePositionAttrs = ArrayAttr::get(positionAttrs.drop_back(), context); extracted = rewriter.create<LLVM::ExtractValueOp>( @@ -278,8 +286,8 @@ public: // Remaining extraction of element from 1-D LLVM vector auto position = positionAttrs.back().cast<IntegerAttr>(); - auto constant = rewriter.create<LLVM::ConstantOp>( - loc, lowering.convertType(i32Type), position); + auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect()); + auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position); extracted = rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); rewriter.replaceOp(op, extracted); @@ -288,6 +296,73 @@ public: } }; +class VectorInsertOpConversion : public LLVMOpLowering { +public: + explicit VectorInsertOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::InsertOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto adaptor = vector::InsertOpOperandAdaptor(operands); + auto insertOp = cast<vector::InsertOp>(op); + auto sourceType = insertOp.getSourceType(); + auto destVectorType = insertOp.getDestVectorType(); + auto llvmResultType = lowering.convertType(destVectorType); + auto positionArrayAttr = insertOp.position(); + + // Bail if result type cannot be lowered. + if (!llvmResultType) + return matchFailure(); + + // One-shot insertion of a vector into an array (only requires insertvalue). + if (sourceType.isa<VectorType>()) { + Value *inserted = rewriter.create<LLVM::InsertValueOp>( + loc, llvmResultType, adaptor.dest(), adaptor.source(), + positionArrayAttr); + rewriter.replaceOp(op, inserted); + return matchSuccess(); + } + + // Potential extraction of 1-D vector from array. + auto *context = op->getContext(); + Value *extracted = adaptor.dest(); + auto positionAttrs = positionArrayAttr.getValue(); + auto position = positionAttrs.back().cast<IntegerAttr>(); + auto oneDVectorType = destVectorType; + if (positionAttrs.size() > 1) { + oneDVectorType = reducedVectorTypeBack(destVectorType); + auto nMinusOnePositionAttrs = + ArrayAttr::get(positionAttrs.drop_back(), context); + extracted = rewriter.create<LLVM::ExtractValueOp>( + loc, lowering.convertType(oneDVectorType), extracted, + nMinusOnePositionAttrs); + } + + // Insertion of an element into a 1-D LLVM vector. + auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect()); + auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position); + Value *inserted = rewriter.create<LLVM::InsertElementOp>( + loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(), + constant); + + // Potential insertion of resulting 1-D vector into array. + if (positionAttrs.size() > 1) { + auto nMinusOnePositionAttrs = + ArrayAttr::get(positionAttrs.drop_back(), context); + inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, + adaptor.dest(), inserted, + nMinusOnePositionAttrs); + } + + rewriter.replaceOp(op, inserted); + return matchSuccess(); + } +}; + class VectorOuterProductOpConversion : public LLVMOpLowering { public: explicit VectorOuterProductOpConversion(MLIRContext *context, @@ -431,8 +506,9 @@ public: /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { - patterns.insert<VectorBroadcastOpConversion, VectorExtractElementOpConversion, - VectorOuterProductOpConversion, VectorTypeCastOpConversion>( + patterns.insert<VectorBroadcastOpConversion, VectorExtractOpConversion, + VectorInsertOpConversion, VectorOuterProductOpConversion, + VectorTypeCastOpConversion>( converter.getDialect()->getContext(), converter); } |