summary refs log tree commit diff stats
path: root/mlir/lib/Conversion
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp35
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp106
2 files changed, 126 insertions, 15 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 15f61ab9ce8..490b6695d84 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -156,10 +156,10 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
// int64_t sizes[Rank]; // omitted when rank == 0
// int64_t strides[Rank]; // omitted when rank == 0
// };
-static unsigned kPtrPosInMemRefDescriptor = 0;
-static unsigned kOffsetPosInMemRefDescriptor = 1;
-static unsigned kSizePosInMemRefDescriptor = 2;
-static unsigned kStridePosInMemRefDescriptor = 3;
+constexpr unsigned LLVMTypeConverter::kPtrPosInMemRefDescriptor;
+constexpr unsigned LLVMTypeConverter::kOffsetPosInMemRefDescriptor;
+constexpr unsigned LLVMTypeConverter::kSizePosInMemRefDescriptor;
+constexpr unsigned LLVMTypeConverter::kStridePosInMemRefDescriptor;
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
int64_t offset;
SmallVector<int64_t, 4> strides;
@@ -282,7 +282,8 @@ public:
Type elementTypePtr) {
return builder.create<LLVM::ExtractValueOp>(
loc, elementTypePtr, memref,
- builder.getIndexArrayAttr(kPtrPosInMemRefDescriptor));
+ builder.getIndexArrayAttr(
+ LLVMTypeConverter::kPtrPosInMemRefDescriptor));
}
protected:
@@ -763,11 +764,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor, allocated,
- rewriter.getIndexArrayAttr(kPtrPosInMemRefDescriptor));
+ rewriter.getIndexArrayAttr(
+ LLVMTypeConverter::kPtrPosInMemRefDescriptor));
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor,
createIndexConstant(rewriter, op->getLoc(), offset),
- rewriter.getIndexArrayAttr(kOffsetPosInMemRefDescriptor));
+ rewriter.getIndexArrayAttr(
+ LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
if (type.getRank() == 0)
// No size/stride descriptor in memref, return the descriptor value.
@@ -798,10 +801,12 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
int64_t index = indexedSize.index();
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
- rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index}));
+ rewriter.getI64ArrayAttr(
+ {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor, strideValues[index],
- rewriter.getI64ArrayAttr({kStridePosInMemRefDescriptor, index}));
+ rewriter.getI64ArrayAttr(
+ {LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
}
// Return the final value of the descriptor.
@@ -896,7 +901,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
}
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
- Type elementPtrType = type.getStructElementType(kPtrPosInMemRefDescriptor);
+ Type elementPtrType =
+ type.getStructElementType(LLVMTypeConverter::kPtrPosInMemRefDescriptor);
Value *bufferPtr = extractMemRefElementPtr(
rewriter, op->getLoc(), transformed.memref(), elementPtrType);
Value *casted = rewriter.create<LLVM::BitcastOp>(
@@ -952,7 +958,8 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
if (ShapedType::isDynamic(shape[index]))
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
op, getIndexType(), transformed.memrefOrTensor(),
- rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index}));
+ rewriter.getI64ArrayAttr(
+ {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
else
// Use constant for static size.
rewriter.replaceOp(
@@ -1015,7 +1022,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
offset == MemRefType::getDynamicStrideOrOffset()
? rewriter.create<LLVM::ExtractValueOp>(
loc, indexTy, memRefDescriptor,
- rewriter.getIndexArrayAttr(kOffsetPosInMemRefDescriptor))
+ rewriter.getIndexArrayAttr(
+ LLVMTypeConverter::kOffsetPosInMemRefDescriptor))
: this->createIndexConstant(rewriter, loc, offset);
for (int i = 0, e = indices.size(); i < e; ++i) {
Value *stride;
@@ -1028,7 +1036,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
// Use dynamic stride.
stride = rewriter.create<LLVM::ExtractValueOp>(
loc, indexTy, memRefDescriptor,
- rewriter.getIndexArrayAttr({kStridePosInMemRefDescriptor, i}));
+ rewriter.getIndexArrayAttr(
+ {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
}
Value *additionalOffset =
rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
index 3c3a18d80e7..765c25ae227 100644
--- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
@@ -155,10 +155,112 @@ public:
}
};
+/// Lowers vector::VectorTypeCastOp to an LLVM memref-descriptor rewrite:
+/// extracts the source buffer pointer, bitcasts it to the target element
+/// type, and rebuilds a descriptor with offset 0 and the target type's
+/// static sizes/strides. Only static-shape, contiguous memrefs are handled.
+class VectorTypeCastOpConversion : public LLVMOpLowering {
+public:
+  explicit VectorTypeCastOpConversion(MLIRContext *context,
+                                      LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(vector::VectorTypeCastOp::getOperationName(), context,
+                       typeConverter) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    vector::VectorTypeCastOp castOp = cast<vector::VectorTypeCastOp>(op);
+    MemRefType sourceMemRefType =
+        castOp.getOperand()->getType().cast<MemRefType>();
+    MemRefType targetMemRefType =
+        castOp.getResult()->getType().cast<MemRefType>();
+
+    // Only static shape casts supported atm.
+    if (!sourceMemRefType.hasStaticShape() ||
+        !targetMemRefType.hasStaticShape())
+      return matchFailure();
+
+    Value *sourceMemRef = operands[0];
+    auto llvmSourceDescriptorTy =
+        sourceMemRef->getType().dyn_cast<LLVM::LLVMType>();
+    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
+      return matchFailure();
+
+    auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
+                                      .dyn_cast_or_null<LLVM::LLVMType>();
+    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+      return matchFailure();
+
+    Type llvmSourceElementTy = llvmSourceDescriptorTy.getStructElementType(
+        LLVMTypeConverter::kPtrPosInMemRefDescriptor);
+    Type llvmTargetElementTy = llvmTargetDescriptorTy.getStructElementType(
+        LLVMTypeConverter::kPtrPosInMemRefDescriptor);
+
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    auto successStrides =
+        getStridesAndOffset(targetMemRefType, strides, offset);
+    // Bail out before inspecting `strides`: when getStridesAndOffset fails
+    // the vector may be empty and strides.back() would be undefined behavior.
+    if (failed(successStrides))
+      return matchFailure();
+    bool isContiguous = (strides.back() == 1);
+    if (isContiguous) {
+      auto sizes = targetMemRefType.getShape();
+      for (int index = 0, e = strides.size() - 2; index < e; ++index) {
+        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
+          isContiguous = false;
+          break;
+        }
+      }
+    }
+    // Only contiguous tensors supported atm.
+    if (!isContiguous)
+      return matchFailure();
+
+    auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+
+    // Create descriptor.
+    Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmTargetDescriptorTy);
+    // Set ptr: reuse the source buffer, retyped to the target element type.
+    Value *ptr = rewriter.create<LLVM::ExtractValueOp>(
+        loc, llvmSourceElementTy, sourceMemRef,
+        rewriter.getIndexArrayAttr(
+            LLVMTypeConverter::kPtrPosInMemRefDescriptor));
+    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
+    desc = rewriter.create<LLVM::InsertValueOp>(
+        op->getLoc(), llvmTargetDescriptorTy, desc, ptr,
+        rewriter.getIndexArrayAttr(
+            LLVMTypeConverter::kPtrPosInMemRefDescriptor));
+    // Fill offset 0.
+    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
+    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
+    desc = rewriter.create<LLVM::InsertValueOp>(
+        op->getLoc(), llvmTargetDescriptorTy, desc, zero,
+        rewriter.getIndexArrayAttr(
+            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+    // Fill size and stride descriptors in memref (all static here).
+    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
+      int64_t index = indexedSize.index();
+      auto sizeAttr =
+          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
+      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
+      desc = rewriter.create<LLVM::InsertValueOp>(
+          op->getLoc(), llvmTargetDescriptorTy, desc, size,
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
+      auto strideAttr =
+          rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
+      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
+      desc = rewriter.create<LLVM::InsertValueOp>(
+          op->getLoc(), llvmTargetDescriptorTy, desc, stride,
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
+    }
+
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
- patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>(
+ patterns.insert<ExtractElementOpConversion, OuterProductOpConversion,
+ VectorTypeCastOpConversion>(
converter.getDialect()->getContext(), converter);
}
@@ -190,5 +292,5 @@ OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() {
}
static PassRegistration<LowerVectorToLLVMPass>
- pass("vector-lower-to-llvm-dialect",
+ pass("convert-vector-to-llvm",
"Lower the operations from the vector dialect into the LLVM dialect");
OpenPOWER on IntegriCloud