diff options
Diffstat (limited to 'mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp')
-rw-r--r-- | mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 35 |
1 files changed, 22 insertions, 13 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 15f61ab9ce8..490b6695d84 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -156,10 +156,10 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature( // int64_t sizes[Rank]; // omitted when rank == 0 // int64_t strides[Rank]; // omitted when rank == 0 // }; -static unsigned kPtrPosInMemRefDescriptor = 0; -static unsigned kOffsetPosInMemRefDescriptor = 1; -static unsigned kSizePosInMemRefDescriptor = 2; -static unsigned kStridePosInMemRefDescriptor = 3; +constexpr unsigned LLVMTypeConverter::kPtrPosInMemRefDescriptor; +constexpr unsigned LLVMTypeConverter::kOffsetPosInMemRefDescriptor; +constexpr unsigned LLVMTypeConverter::kSizePosInMemRefDescriptor; +constexpr unsigned LLVMTypeConverter::kStridePosInMemRefDescriptor; Type LLVMTypeConverter::convertMemRefType(MemRefType type) { int64_t offset; SmallVector<int64_t, 4> strides; @@ -282,7 +282,8 @@ public: Type elementTypePtr) { return builder.create<LLVM::ExtractValueOp>( loc, elementTypePtr, memref, - builder.getIndexArrayAttr(kPtrPosInMemRefDescriptor)); + builder.getIndexArrayAttr( + LLVMTypeConverter::kPtrPosInMemRefDescriptor)); } protected: @@ -763,11 +764,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( op->getLoc(), structType, memRefDescriptor, allocated, - rewriter.getIndexArrayAttr(kPtrPosInMemRefDescriptor)); + rewriter.getIndexArrayAttr( + LLVMTypeConverter::kPtrPosInMemRefDescriptor)); memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( op->getLoc(), structType, memRefDescriptor, createIndexConstant(rewriter, op->getLoc(), offset), - rewriter.getIndexArrayAttr(kOffsetPosInMemRefDescriptor)); + rewriter.getIndexArrayAttr( + LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); if (type.getRank() == 0) // No size/stride descriptor in memref, return the descriptor value. @@ -798,10 +801,12 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { int64_t index = indexedSize.index(); memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( op->getLoc(), structType, memRefDescriptor, indexedSize.value(), - rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index})); + rewriter.getI64ArrayAttr( + {LLVMTypeConverter::kSizePosInMemRefDescriptor, index})); memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( op->getLoc(), structType, memRefDescriptor, strideValues[index], - rewriter.getI64ArrayAttr({kStridePosInMemRefDescriptor, index})); + rewriter.getI64ArrayAttr( + {LLVMTypeConverter::kStridePosInMemRefDescriptor, index})); } // Return the final value of the descriptor. @@ -896,7 +901,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> { } auto type = transformed.memref()->getType().cast<LLVM::LLVMType>(); - Type elementPtrType = type.getStructElementType(kPtrPosInMemRefDescriptor); + Type elementPtrType = + type.getStructElementType(LLVMTypeConverter::kPtrPosInMemRefDescriptor); Value *bufferPtr = extractMemRefElementPtr( rewriter, op->getLoc(), transformed.memref(), elementPtrType); Value *casted = rewriter.create<LLVM::BitcastOp>( @@ -952,7 +958,8 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> { if (ShapedType::isDynamic(shape[index])) rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>( op, getIndexType(), transformed.memrefOrTensor(), - rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index})); + rewriter.getI64ArrayAttr( + {LLVMTypeConverter::kSizePosInMemRefDescriptor, index})); else // Use constant for static size. rewriter.replaceOp( @@ -1015,7 +1022,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> { offset == MemRefType::getDynamicStrideOrOffset() ? rewriter.create<LLVM::ExtractValueOp>( loc, indexTy, memRefDescriptor, - rewriter.getIndexArrayAttr(kOffsetPosInMemRefDescriptor)) + rewriter.getIndexArrayAttr( + LLVMTypeConverter::kOffsetPosInMemRefDescriptor)) : this->createIndexConstant(rewriter, loc, offset); for (int i = 0, e = indices.size(); i < e; ++i) { Value *stride; @@ -1028,7 +1036,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> { // Use dynamic stride. stride = rewriter.create<LLVM::ExtractValueOp>( loc, indexTy, memRefDescriptor, - rewriter.getIndexArrayAttr({kStridePosInMemRefDescriptor, i})); + rewriter.getIndexArrayAttr( + {LLVMTypeConverter::kStridePosInMemRefDescriptor, i})); } Value *additionalOffset = rewriter.create<LLVM::MulOp>(loc, indices[i], stride); |