summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp5
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp52
2 files changed, 19 insertions, 38 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 570b6c4bcf2..e0edb0bc047 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -344,6 +344,11 @@ void MemRefDescriptor::setPtr(OpBuilder &builder, Location loc, unsigned pos,
builder.getI64ArrayAttr(pos));
}
+LLVM::LLVMType MemRefDescriptor::getElementType() {
+ return value->getType().cast<LLVM::LLVMType>().getStructElementType(
+ LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
+}
+
namespace {
// Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in
diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
index 21bcdc9a6db..5bda8b3fd5b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
@@ -177,22 +177,17 @@ public:
!targetMemRefType.hasStaticShape())
return matchFailure();
- Value *sourceMemRef = operands[0];
auto llvmSourceDescriptorTy =
- sourceMemRef->getType().dyn_cast<LLVM::LLVMType>();
+ operands[0]->getType().dyn_cast<LLVM::LLVMType>();
if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
return matchFailure();
+ MemRefDescriptor sourceMemRef(operands[0]);
auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return matchFailure();
- Type llvmSourceElementTy = llvmSourceDescriptorTy.getStructElementType(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
- Type llvmTargetElementTy = llvmTargetDescriptorTy.getStructElementType(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
-
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides =
@@ -214,55 +209,36 @@ public:
auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
// Create descriptor.
- Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmTargetDescriptorTy);
+ auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
+ Type llvmTargetElementTy = desc.getElementType();
// Set allocated ptr.
- Value *allocated = rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmSourceElementTy, sourceMemRef,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+ Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc);
allocated =
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, allocated,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
- // Set ptr.
- Value *ptr = rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmSourceElementTy, sourceMemRef,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+ desc.setAllocatedPtr(rewriter, loc, allocated);
+ // Set aligned ptr.
+ Value *ptr = sourceMemRef.alignedPtr(rewriter, loc);
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, ptr,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+ desc.setAlignedPtr(rewriter, loc, ptr);
// Fill offset 0.
auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, zero,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+ desc.setOffset(rewriter, loc, zero);
+
// Fill size and stride descriptors in memref.
for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
int64_t index = indexedSize.index();
auto sizeAttr =
rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, size,
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
+ desc.setSize(rewriter, loc, index, size);
auto strideAttr =
rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, stride,
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
+ desc.setStride(rewriter, loc, index, stride);
}
- rewriter.replaceOp(op, desc);
+ rewriter.replaceOp(op, {desc});
return matchSuccess();
}
};
OpenPOWER on IntegriCloud