summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp')
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp35
1 files changed, 22 insertions, 13 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 15f61ab9ce8..490b6695d84 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -156,10 +156,10 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
// int64_t sizes[Rank]; // omitted when rank == 0
// int64_t strides[Rank]; // omitted when rank == 0
// };
-static unsigned kPtrPosInMemRefDescriptor = 0;
-static unsigned kOffsetPosInMemRefDescriptor = 1;
-static unsigned kSizePosInMemRefDescriptor = 2;
-static unsigned kStridePosInMemRefDescriptor = 3;
+constexpr unsigned LLVMTypeConverter::kPtrPosInMemRefDescriptor;
+constexpr unsigned LLVMTypeConverter::kOffsetPosInMemRefDescriptor;
+constexpr unsigned LLVMTypeConverter::kSizePosInMemRefDescriptor;
+constexpr unsigned LLVMTypeConverter::kStridePosInMemRefDescriptor;
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
int64_t offset;
SmallVector<int64_t, 4> strides;
@@ -282,7 +282,8 @@ public:
Type elementTypePtr) {
return builder.create<LLVM::ExtractValueOp>(
loc, elementTypePtr, memref,
- builder.getIndexArrayAttr(kPtrPosInMemRefDescriptor));
+ builder.getIndexArrayAttr(
+ LLVMTypeConverter::kPtrPosInMemRefDescriptor));
}
protected:
@@ -763,11 +764,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor, allocated,
- rewriter.getIndexArrayAttr(kPtrPosInMemRefDescriptor));
+ rewriter.getIndexArrayAttr(
+ LLVMTypeConverter::kPtrPosInMemRefDescriptor));
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor,
createIndexConstant(rewriter, op->getLoc(), offset),
- rewriter.getIndexArrayAttr(kOffsetPosInMemRefDescriptor));
+ rewriter.getIndexArrayAttr(
+ LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
if (type.getRank() == 0)
// No size/stride descriptor in memref, return the descriptor value.
@@ -798,10 +801,12 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
int64_t index = indexedSize.index();
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
- rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index}));
+ rewriter.getI64ArrayAttr(
+ {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor, strideValues[index],
- rewriter.getI64ArrayAttr({kStridePosInMemRefDescriptor, index}));
+ rewriter.getI64ArrayAttr(
+ {LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
}
// Return the final value of the descriptor.
@@ -896,7 +901,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
}
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
- Type elementPtrType = type.getStructElementType(kPtrPosInMemRefDescriptor);
+ Type elementPtrType =
+ type.getStructElementType(LLVMTypeConverter::kPtrPosInMemRefDescriptor);
Value *bufferPtr = extractMemRefElementPtr(
rewriter, op->getLoc(), transformed.memref(), elementPtrType);
Value *casted = rewriter.create<LLVM::BitcastOp>(
@@ -952,7 +958,8 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
if (ShapedType::isDynamic(shape[index]))
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
op, getIndexType(), transformed.memrefOrTensor(),
- rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index}));
+ rewriter.getI64ArrayAttr(
+ {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
else
// Use constant for static size.
rewriter.replaceOp(
@@ -1015,7 +1022,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
offset == MemRefType::getDynamicStrideOrOffset()
? rewriter.create<LLVM::ExtractValueOp>(
loc, indexTy, memRefDescriptor,
- rewriter.getIndexArrayAttr(kOffsetPosInMemRefDescriptor))
+ rewriter.getIndexArrayAttr(
+ LLVMTypeConverter::kOffsetPosInMemRefDescriptor))
: this->createIndexConstant(rewriter, loc, offset);
for (int i = 0, e = indices.size(); i < e; ++i) {
Value *stride;
@@ -1028,7 +1036,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
// Use dynamic stride.
stride = rewriter.create<LLVM::ExtractValueOp>(
loc, indexTy, memRefDescriptor,
- rewriter.getIndexArrayAttr({kStridePosInMemRefDescriptor, i}));
+ rewriter.getIndexArrayAttr(
+ {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
}
Value *additionalOffset =
rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
OpenPOWER on IntegriCloud