diff options
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp | 106 |
1 files changed, 104 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp index 3c3a18d80e7..765c25ae227 100644 --- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp @@ -155,10 +155,112 @@ public: } }; +class VectorTypeCastOpConversion : public LLVMOpLowering { +public: + explicit VectorTypeCastOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::VectorTypeCastOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + vector::VectorTypeCastOp castOp = cast<vector::VectorTypeCastOp>(op); + MemRefType sourceMemRefType = + castOp.getOperand()->getType().cast<MemRefType>(); + MemRefType targetMemRefType = + castOp.getResult()->getType().cast<MemRefType>(); + + // Only static shape casts supported atm. + if (!sourceMemRefType.hasStaticShape() || + !targetMemRefType.hasStaticShape()) + return matchFailure(); + + Value *sourceMemRef = operands[0]; + auto llvmSourceDescriptorTy = + sourceMemRef->getType().dyn_cast<LLVM::LLVMType>(); + if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) + return matchFailure(); + + auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType) + .dyn_cast_or_null<LLVM::LLVMType>(); + if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) + return matchFailure(); + + Type llvmSourceElementTy = llvmSourceDescriptorTy.getStructElementType( + LLVMTypeConverter::kPtrPosInMemRefDescriptor); + Type llvmTargetElementTy = llvmTargetDescriptorTy.getStructElementType( + LLVMTypeConverter::kPtrPosInMemRefDescriptor); + + int64_t offset; + SmallVector<int64_t, 4> strides; + auto successStrides = + getStridesAndOffset(targetMemRefType, strides, offset); + bool isContiguous = (strides.back() == 1); + if (isContiguous) { + auto sizes = targetMemRefType.getShape(); + for (int index = 0, e = strides.size() - 2; index < e; ++index) { + if (strides[index] != strides[index + 1] * sizes[index + 1]) { + isContiguous = false; + break; + } + } + } + // Only contiguous tensors supported atm. + if (failed(successStrides) || !isContiguous) + return matchFailure(); + + auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); + + // Create descriptor. + Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmTargetDescriptorTy); + // Set ptr. + Value *ptr = rewriter.create<LLVM::ExtractValueOp>( + loc, llvmSourceElementTy, sourceMemRef, + rewriter.getIndexArrayAttr( + LLVMTypeConverter::kPtrPosInMemRefDescriptor)); + ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); + desc = rewriter.create<LLVM::InsertValueOp>( + op->getLoc(), llvmTargetDescriptorTy, desc, ptr, + rewriter.getIndexArrayAttr( + LLVMTypeConverter::kPtrPosInMemRefDescriptor)); + // Fill offset 0. + auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); + auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); + desc = rewriter.create<LLVM::InsertValueOp>( + op->getLoc(), llvmTargetDescriptorTy, desc, zero, + rewriter.getIndexArrayAttr( + LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); + // Fill size and stride descriptors in memref. + for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { + int64_t index = indexedSize.index(); + auto sizeAttr = + rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); + auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); + desc = rewriter.create<LLVM::InsertValueOp>( + op->getLoc(), llvmTargetDescriptorTy, desc, size, + rewriter.getI64ArrayAttr( + {LLVMTypeConverter::kSizePosInMemRefDescriptor, index})); + auto strideAttr = + rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); + auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); + desc = rewriter.create<LLVM::InsertValueOp>( + op->getLoc(), llvmTargetDescriptorTy, desc, stride, + rewriter.getI64ArrayAttr( + {LLVMTypeConverter::kStridePosInMemRefDescriptor, index})); + } + + rewriter.replaceOp(op, desc); + return matchSuccess(); + } +}; + /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { - patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>( + patterns.insert<ExtractElementOpConversion, OuterProductOpConversion, + VectorTypeCastOpConversion>( converter.getDialect()->getContext(), converter); } @@ -190,5 +292,5 @@ OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() { } static PassRegistration<LowerVectorToLLVMPass> - pass("vector-lower-to-llvm-dialect", + pass("convert-vector-to-llvm", "Lower the operations from the vector dialect into the LLVM dialect"); |