Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r--   mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp | 106
1 file changed, 104 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
index 3c3a18d80e7..765c25ae227 100644
--- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
@@ -155,10 +155,112 @@ public:
   }
 };
+class VectorTypeCastOpConversion : public LLVMOpLowering {
+public:
+  explicit VectorTypeCastOpConversion(MLIRContext *context,
+                                      LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(vector::VectorTypeCastOp::getOperationName(), context,
+                       typeConverter) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    vector::VectorTypeCastOp castOp = cast<vector::VectorTypeCastOp>(op);
+    MemRefType sourceMemRefType =
+        castOp.getOperand()->getType().cast<MemRefType>();
+    MemRefType targetMemRefType =
+        castOp.getResult()->getType().cast<MemRefType>();
+
+    // Only static shape casts supported atm.
+    if (!sourceMemRefType.hasStaticShape() ||
+        !targetMemRefType.hasStaticShape())
+      return matchFailure();
+
+    Value *sourceMemRef = operands[0];
+    auto llvmSourceDescriptorTy =
+        sourceMemRef->getType().dyn_cast<LLVM::LLVMType>();
+    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
+      return matchFailure();
+
+    auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
+                                      .dyn_cast_or_null<LLVM::LLVMType>();
+    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+      return matchFailure();
+
+    Type llvmSourceElementTy = llvmSourceDescriptorTy.getStructElementType(
+        LLVMTypeConverter::kPtrPosInMemRefDescriptor);
+    Type llvmTargetElementTy = llvmTargetDescriptorTy.getStructElementType(
+        LLVMTypeConverter::kPtrPosInMemRefDescriptor);
+
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    auto successStrides =
+        getStridesAndOffset(targetMemRefType, strides, offset);
+    bool isContiguous = (strides.back() == 1);
+    if (isContiguous) {
+      auto sizes = targetMemRefType.getShape();
+      for (int index = 0, e = strides.size() - 2; index < e; ++index) {
+        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
+          isContiguous = false;
+          break;
+        }
+      }
+    }
+    // Only contiguous tensors supported atm.
+    if (failed(successStrides) || !isContiguous)
+      return matchFailure();
+
+    auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+
+    // Create descriptor.
+    Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmTargetDescriptorTy);
+    // Set ptr.
+    Value *ptr = rewriter.create<LLVM::ExtractValueOp>(
+        loc, llvmSourceElementTy, sourceMemRef,
+        rewriter.getIndexArrayAttr(
+            LLVMTypeConverter::kPtrPosInMemRefDescriptor));
+    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
+    desc = rewriter.create<LLVM::InsertValueOp>(
+        op->getLoc(), llvmTargetDescriptorTy, desc, ptr,
+        rewriter.getIndexArrayAttr(
+            LLVMTypeConverter::kPtrPosInMemRefDescriptor));
+    // Fill offset 0.
+    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
+    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
+    desc = rewriter.create<LLVM::InsertValueOp>(
+        op->getLoc(), llvmTargetDescriptorTy, desc, zero,
+        rewriter.getIndexArrayAttr(
+            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+    // Fill size and stride descriptors in memref.
+    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
+      int64_t index = indexedSize.index();
+      auto sizeAttr =
+          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
+      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
+      desc = rewriter.create<LLVM::InsertValueOp>(
+          op->getLoc(), llvmTargetDescriptorTy, desc, size,
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
+      auto strideAttr =
+          rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
+      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
+      desc = rewriter.create<LLVM::InsertValueOp>(
+          op->getLoc(), llvmTargetDescriptorTy, desc, stride,
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
+    }
+
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>(
+  patterns.insert<ExtractElementOpConversion, OuterProductOpConversion,
+                  VectorTypeCastOpConversion>(
       converter.getDialect()->getContext(), converter);
 }
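
The new pattern above only fires for memrefs with static shapes and a contiguous row-major layout: the innermost stride must be 1 and each outer stride must equal the next stride multiplied by the next size. As a standalone illustration of that condition (the helper name and signature below are illustrative only, not part of the patch, and the loop spells out every adjacent stride pair):

#include "llvm/ADT/ArrayRef.h"

#include <cstdint>

// Row-major contiguity: the innermost stride is 1 and each outer stride
// equals the product of the next stride and the next size.
static bool isRowMajorContiguous(llvm::ArrayRef<int64_t> sizes,
                                 llvm::ArrayRef<int64_t> strides) {
  if (strides.empty() || strides.back() != 1)
    return false;
  for (size_t i = 0, e = strides.size() - 1; i < e; ++i)
    if (strides[i] != strides[i + 1] * sizes[i + 1])
      return false;
  return true;
}

For example, sizes {2, 3, 4} with strides {12, 4, 1} satisfy the condition, while strides {16, 4, 1} (a padded outer dimension) do not.
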
@@ -190,5 +292,5 @@ OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() {
 }
 static PassRegistration<LowerVectorToLLVMPass>
-    pass("vector-lower-to-llvm-dialect",
+    pass("convert-vector-to-llvm",
          "Lower the operations from the vector dialect into the LLVM dialect");