diff options
author | Nicolas Vasilache <ntv@google.com> | 2019-11-12 07:06:18 -0800 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-12 07:06:54 -0800 |
commit | f51a15533729cddc9907320b5ab963f7fc037aa0 (patch) | |
tree | f2c907948d881741510dc8574fe12dc468a72021 /mlir/lib | |
parent | 6582489219ab695a025457302a9e6924b1259176 (diff) | |
download | bcm5719-llvm-f51a15533729cddc9907320b5ab963f7fc037aa0.tar.gz bcm5719-llvm-f51a15533729cddc9907320b5ab963f7fc037aa0.zip |
Add support for alignment attribute in std.alloc.
This CL adds an extra pointer to the memref descriptor to allow specifying alignment.
In a previous implementation, we used 2 types: `linalg.buffer` and `view` where the buffer type was the unit of allocation/deallocation/alignment and `view` was the unit of indexing.
After multiple discussions it was decided to use a single type, which conflates both, so the memref descriptor now needs to carry both pointers.
This is consistent with the [RFC-Proposed Changes to MemRef and Tensor MLIR Types](https://groups.google.com/a/tensorflow.org/forum/#!searchin/mlir/std.view%7Csort:date/mlir/-wKHANzDNTg/4K6nUAp8AAAJ).
PiperOrigin-RevId: 279959463
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 186 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp | 19 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp | 78 | ||||
-rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 4 |
4 files changed, 188 insertions, 99 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 1584bd4d4ed..84eba820b80 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -151,12 +151,14 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature( // // template <typename Elem, size_t Rank> // struct { -// Elem *ptr; +// Elem *allocatedPtr; +// Elem *alignedPtr; // int64_t offset; // int64_t sizes[Rank]; // omitted when rank == 0 // int64_t strides[Rank]; // omitted when rank == 0 // }; -constexpr unsigned LLVMTypeConverter::kPtrPosInMemRefDescriptor; +constexpr unsigned LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor; +constexpr unsigned LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor; constexpr unsigned LLVMTypeConverter::kOffsetPosInMemRefDescriptor; constexpr unsigned LLVMTypeConverter::kSizePosInMemRefDescriptor; constexpr unsigned LLVMTypeConverter::kStridePosInMemRefDescriptor; @@ -175,9 +177,9 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) { auto rank = type.getRank(); if (rank > 0) { auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, type.getRank()); - return LLVM::LLVMType::getStructTy(ptrTy, indexTy, arrayTy, arrayTy); + return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy, arrayTy, arrayTy); } - return LLVM::LLVMType::getStructTy(ptrTy, indexTy); + return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy); } // Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when @@ -276,14 +278,27 @@ public: return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr); } - // Extract raw data pointer value from a value representing a memref. - static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder, - Location loc, Value *memref, - Type elementTypePtr) { + // Extract allocated data pointer value from a value representing a memref. + static Value * + extractAllocatedMemRefElementPtr(ConversionPatternRewriter &builder, + Location loc, Value *memref, + Type elementTypePtr) { return builder.create<LLVM::ExtractValueOp>( loc, elementTypePtr, memref, - builder.getIndexArrayAttr( - LLVMTypeConverter::kPtrPosInMemRefDescriptor)); + builder.getI64ArrayAttr( + LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); + } + + // Extract properly aligned data pointer value from a value representing a + // memref. + static Value * + extractAlignedMemRefElementPtr(ConversionPatternRewriter &builder, + Location loc, Value *memref, + Type elementTypePtr) { + return builder.create<LLVM::ExtractValueOp>( + loc, elementTypePtr, memref, + builder.getI64ArrayAttr( + LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); } protected: @@ -442,7 +457,7 @@ void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, if (coords.empty()) break; assert(coords.size() == info.arraySizes.size()); - auto position = builder.getIndexArrayAttr(coords); + auto position = builder.getI64ArrayAttr(coords); fun(position); } } @@ -488,7 +503,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { auto type = this->lowering.convertType(op->getResult(i)->getType()); results.push_back(rewriter.create<LLVM::ExtractValueOp>( op->getLoc(), type, newOp.getOperation()->getResult(0), - rewriter.getIndexArrayAttr(i))); + rewriter.getI64ArrayAttr(i))); } rewriter.replaceOp(op, results); return this->matchSuccess(); @@ -650,9 +665,16 @@ static bool isSupportedMemRefType(MemRefType type) { // An `alloc` is converted into a definition of a memref descriptor value and // a call to `malloc` to allocate the underlying data buffer. The memref -// descriptor is of the LLVM structure type where the first element is a pointer -// to the (typed) data buffer, and the remaining elements serve to store -// dynamic sizes of the memref using LLVM-converted `index` type. +// descriptor is of the LLVM structure type where: +// 1. the first element is a pointer to the allocated (typed) data buffer, +// 2. the second element is a pointer to the (typed) payload, aligned to the +// specified alignment, +// 3. the remaining elements serve to store all the sizes and strides of the +// memref using LLVM-converted `index` type. +// +// Alignment is obtained by allocating `alignment - 1` more bytes than requested +// and shifting the aligned pointer relative to the allocated memory. If +// alignment is unspecified, the two pointers are equal. struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern; @@ -678,6 +700,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { void rewrite(Operation *op, ArrayRef<Value *> operands, ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); auto allocOp = cast<AllocOp>(op); MemRefType type = allocOp.getType(); @@ -689,16 +712,15 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { unsigned i = 0; for (int64_t s : type.getShape()) sizes.push_back(s == -1 ? operands[i++] - : createIndexConstant(rewriter, op->getLoc(), s)); + : createIndexConstant(rewriter, loc, s)); if (sizes.empty()) - sizes.push_back(createIndexConstant(rewriter, op->getLoc(), 1)); + sizes.push_back(createIndexConstant(rewriter, loc, 1)); // Compute the total number of memref elements. Value *cumulativeSize = sizes.front(); for (unsigned i = 1, e = sizes.size(); i < e; ++i) cumulativeSize = rewriter.create<LLVM::MulOp>( - op->getLoc(), getIndexType(), - ArrayRef<Value *>{cumulativeSize, sizes[i]}); + loc, getIndexType(), ArrayRef<Value *>{cumulativeSize, sizes[i]}); // Compute the size of an individual element. This emits the MLIR equivalent // of the following sizeof(...) implementation in LLVM IR: @@ -708,16 +730,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { auto elementType = type.getElementType(); auto convertedPtrType = lowering.convertType(elementType).cast<LLVM::LLVMType>().getPointerTo(); - auto nullPtr = - rewriter.create<LLVM::NullOp>(op->getLoc(), convertedPtrType); - auto one = createIndexConstant(rewriter, op->getLoc(), 1); - auto gep = rewriter.create<LLVM::GEPOp>(op->getLoc(), convertedPtrType, + auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType); + auto one = createIndexConstant(rewriter, loc, 1); + auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, ArrayRef<Value *>{nullPtr, one}); auto elementSize = - rewriter.create<LLVM::PtrToIntOp>(op->getLoc(), getIndexType(), gep); + rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep); cumulativeSize = rewriter.create<LLVM::MulOp>( - op->getLoc(), getIndexType(), - ArrayRef<Value *>{cumulativeSize, elementSize}); + loc, getIndexType(), ArrayRef<Value *>{cumulativeSize, elementSize}); // Insert the `malloc` declaration if it is not already present. auto module = op->getParentOfType<ModuleOp>(); @@ -732,17 +752,24 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. + Value *align = nullptr; + if (auto alignAttr = allocOp.alignment()) { + align = createIndexConstant(rewriter, loc, + alignAttr.getValue().getSExtValue()); + cumulativeSize = rewriter.create<LLVM::SubOp>( + loc, rewriter.create<LLVM::AddOp>(loc, cumulativeSize, align), one); + } Value *allocated = rewriter - .create<LLVM::CallOp>(op->getLoc(), getVoidPtrType(), + .create<LLVM::CallOp>(loc, getVoidPtrType(), rewriter.getSymbolRefAttr(mallocFunc), cumulativeSize) .getResult(0); auto structElementType = lowering.convertType(elementType); auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo( type.getMemorySpace()); - allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType, - ArrayRef<Value *>(allocated)); + Value *bitcastAllocated = rewriter.create<LLVM::BitcastOp>( + loc, elementPtrType, ArrayRef<Value *>(allocated)); int64_t offset; SmallVector<int64_t, 4> strides; @@ -759,23 +786,44 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { // Create the MemRef descriptor. auto structType = lowering.convertType(type); - Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>( - op->getLoc(), structType, ArrayRef<Value *>{}); + Value *memRefDescriptor = + rewriter.create<LLVM::UndefOp>(loc, structType, ArrayRef<Value *>{}); + // Field 1: Allocated pointer, used for malloc/free. + memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( + loc, structType, memRefDescriptor, bitcastAllocated, + rewriter.getI64ArrayAttr( + LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); + // Field 2: Actual aligned pointer to payload. + Value *bitcastAligned = bitcastAllocated; + if (align) { + // offset = (align - (ptr % align))% align + Value *intVal = rewriter.create<LLVM::PtrToIntOp>( + loc, this->getIndexType(), allocated); + Value *ptrModAlign = rewriter.create<LLVM::URemOp>(loc, intVal, align); + Value *subbed = rewriter.create<LLVM::SubOp>(loc, align, ptrModAlign); + Value *offset = rewriter.create<LLVM::URemOp>(loc, subbed, align); + Value *aligned = rewriter.create<LLVM::GEPOp>(loc, allocated->getType(), + allocated, offset); + bitcastAligned = rewriter.create<LLVM::BitcastOp>( + loc, elementPtrType, ArrayRef<Value *>(aligned)); + } memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), structType, memRefDescriptor, allocated, - rewriter.getIndexArrayAttr( - LLVMTypeConverter::kPtrPosInMemRefDescriptor)); + loc, structType, memRefDescriptor, bitcastAligned, + rewriter.getI64ArrayAttr( + LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); + // Field 3: Offset in aligned pointer. memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), structType, memRefDescriptor, - createIndexConstant(rewriter, op->getLoc(), offset), - rewriter.getIndexArrayAttr( + loc, structType, memRefDescriptor, + createIndexConstant(rewriter, loc, offset), + rewriter.getI64ArrayAttr( LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); if (type.getRank() == 0) // No size/stride descriptor in memref, return the descriptor value. return rewriter.replaceOp(op, memRefDescriptor); + // Fields 4 and 5: Sizes and strides of the strided MemRef. // Store all sizes in the descriptor. Only dynamic sizes are passed in as // operands to AllocOp. Value *runningStride = nullptr; @@ -787,24 +835,23 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { if (strides[index] == MemRefType::getDynamicStrideOrOffset()) // Identity layout map is enforced in the match function, so we compute: // `runningStride *= sizes[index]` - runningStride = runningStride - ? rewriter.create<LLVM::MulOp>( - op->getLoc(), runningStride, sizes[index]) - : createIndexConstant(rewriter, op->getLoc(), 1); - else runningStride = - createIndexConstant(rewriter, op->getLoc(), strides[index]); + runningStride + ? rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[index]) + : createIndexConstant(rewriter, loc, 1); + else + runningStride = createIndexConstant(rewriter, loc, strides[index]); strideValues[index] = runningStride; } // Fill size and stride descriptors in memref. for (auto indexedSize : llvm::enumerate(sizes)) { int64_t index = indexedSize.index(); memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), structType, memRefDescriptor, indexedSize.value(), + loc, structType, memRefDescriptor, indexedSize.value(), rewriter.getI64ArrayAttr( {LLVMTypeConverter::kSizePosInMemRefDescriptor, index})); memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), structType, memRefDescriptor, strideValues[index], + loc, structType, memRefDescriptor, strideValues[index], rewriter.getI64ArrayAttr( {LLVMTypeConverter::kStridePosInMemRefDescriptor, index})); } @@ -861,7 +908,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> { auto type = this->lowering.convertType(op->getResult(i)->getType()); results.push_back(rewriter.create<LLVM::ExtractValueOp>( op->getLoc(), type, newOp.getOperation()->getResult(0), - rewriter.getIndexArrayAttr(i))); + rewriter.getI64ArrayAttr(i))); } rewriter.replaceOp(op, results); @@ -901,9 +948,9 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> { } auto type = transformed.memref()->getType().cast<LLVM::LLVMType>(); - Type elementPtrType = - type.getStructElementType(LLVMTypeConverter::kPtrPosInMemRefDescriptor); - Value *bufferPtr = extractMemRefElementPtr( + Type elementPtrType = type.getStructElementType( + LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor); + Value *bufferPtr = extractAllocatedMemRefElementPtr( rewriter, op->getLoc(), transformed.memref(), elementPtrType); Value *casted = rewriter.create<LLVM::BitcastOp>( op->getLoc(), getVoidPtrType(), bufferPtr); @@ -1016,13 +1063,13 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> { ArrayRef<int64_t> strides, int64_t offset, ConversionPatternRewriter &rewriter) const { auto indexTy = this->getIndexType(); - Value *base = this->extractMemRefElementPtr(rewriter, loc, memRefDescriptor, - elementTypePtr); + Value *base = this->extractAlignedMemRefElementPtr( + rewriter, loc, memRefDescriptor, elementTypePtr); Value *offsetValue = offset == MemRefType::getDynamicStrideOrOffset() ? rewriter.create<LLVM::ExtractValueOp>( loc, indexTy, memRefDescriptor, - rewriter.getIndexArrayAttr( + rewriter.getI64ArrayAttr( LLVMTypeConverter::kOffsetPosInMemRefDescriptor)) : this->createIndexConstant(rewriter, loc, offset); for (int i = 0, e = indices.size(); i < e; ++i) { @@ -1036,7 +1083,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> { // Use dynamic stride. stride = rewriter.create<LLVM::ExtractValueOp>( loc, indexTy, memRefDescriptor, - rewriter.getIndexArrayAttr( + rewriter.getI64ArrayAttr( {LLVMTypeConverter::kStridePosInMemRefDescriptor, i})); } Value *additionalOffset = @@ -1261,7 +1308,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> { for (unsigned i = 0; i < numArguments; ++i) { packed = rewriter.create<LLVM::InsertValueOp>( op->getLoc(), packedType, packed, operands[i], - rewriter.getIndexArrayAttr(i)); + rewriter.getI64ArrayAttr(i)); } rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(), @@ -1436,19 +1483,32 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> { // Create the descriptor. Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy); - // Copy the buffer pointer from the old descriptor to the new one. + // Field 1: Copy the allocated pointer, used for malloc/free. Value *sourceDescriptor = adaptor.source(); + Value *extracted = rewriter.create<LLVM::ExtractValueOp>( + loc, sourceElementTy.getPointerTo(), sourceDescriptor, + rewriter.getI64ArrayAttr( + LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>( - loc, targetElementTy.getPointerTo(), - rewriter.create<LLVM::ExtractValueOp>( - loc, sourceElementTy.getPointerTo(), sourceDescriptor, - rewriter.getI64ArrayAttr( - LLVMTypeConverter::kPtrPosInMemRefDescriptor))); + loc, targetElementTy.getPointerTo(), extracted); desc = rewriter.create<LLVM::InsertValueOp>( loc, desc, bitcastPtr, - rewriter.getI64ArrayAttr(LLVMTypeConverter::kPtrPosInMemRefDescriptor)); + rewriter.getI64ArrayAttr( + LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); + + // Field 2: Copy the actual aligned pointer to payload. + extracted = rewriter.create<LLVM::ExtractValueOp>( + loc, sourceElementTy.getPointerTo(), sourceDescriptor, + rewriter.getI64ArrayAttr( + LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); + bitcastPtr = rewriter.create<LLVM::BitcastOp>( + loc, targetElementTy.getPointerTo(), extracted); + desc = rewriter.create<LLVM::InsertValueOp>( + loc, desc, bitcastPtr, + rewriter.getI64ArrayAttr( + LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); - // Offset. + // Field 3: Copy the offset in aligned pointer. unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes()); (void)numDynamicSizes; auto sizeAndOffsetOperands = adaptor.operands(); @@ -1467,7 +1527,7 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> { if (viewMemRefType.getRank() == 0) return rewriter.replaceOp(op, desc), matchSuccess(); - // Update sizes and strides. + // Fields 4 and 5: Update sizes and strides. if (strides.back() != 1) return op->emitWarning("cannot cast to non-contiguous shape"), matchFailure(); diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp index 765c25ae227..5ccf740f2fb 100644 --- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp @@ -189,9 +189,9 @@ public: return matchFailure(); Type llvmSourceElementTy = llvmSourceDescriptorTy.getStructElementType( - LLVMTypeConverter::kPtrPosInMemRefDescriptor); + LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); Type llvmTargetElementTy = llvmTargetDescriptorTy.getStructElementType( - LLVMTypeConverter::kPtrPosInMemRefDescriptor); + LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); int64_t offset; SmallVector<int64_t, 4> strides; @@ -215,16 +215,27 @@ public: // Create descriptor. Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmTargetDescriptorTy); + // Set allocated ptr. + Value *allocated = rewriter.create<LLVM::ExtractValueOp>( + loc, llvmSourceElementTy, sourceMemRef, + rewriter.getIndexArrayAttr( + LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); + allocated = + rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); + desc = rewriter.create<LLVM::InsertValueOp>( + op->getLoc(), llvmTargetDescriptorTy, desc, allocated, + rewriter.getIndexArrayAttr( + LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); // Set ptr. Value *ptr = rewriter.create<LLVM::ExtractValueOp>( loc, llvmSourceElementTy, sourceMemRef, rewriter.getIndexArrayAttr( - LLVMTypeConverter::kPtrPosInMemRefDescriptor)); + LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); desc = rewriter.create<LLVM::InsertValueOp>( op->getLoc(), llvmTargetDescriptorTy, desc, ptr, rewriter.getIndexArrayAttr( - LLVMTypeConverter::kPtrPosInMemRefDescriptor)); + LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); // Fill offset 0. auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp index 2dc46bf7b2b..05f4bced063 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -127,14 +127,6 @@ static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) { return Type(); } -static constexpr int kBasePtrPosInBuffer = 0; -static constexpr int kPtrPosInBuffer = 1; -static constexpr int kSizePosInBuffer = 2; -static constexpr int kPtrPosInView = 0; -static constexpr int kOffsetPosInView = 1; -static constexpr int kSizePosInView = 2; -static constexpr int kStridePosInView = 3; - namespace { /// Factor out the common information for all view conversions: /// 1. common types in (standard and LLVM dialects) @@ -224,12 +216,14 @@ public: // TODO(ntv): extract sizes and emit asserts. SmallVector<Value *, 4> strides(memRefType.getRank()); for (int i = 0, e = memRefType.getRank(); i < e; ++i) - strides[i] = - extractvalue(int64Ty, baseDesc, helper.pos({kStridePosInView, i})); + strides[i] = extractvalue( + int64Ty, baseDesc, + helper.pos({LLVMTypeConverter::kStridePosInMemRefDescriptor, i})); // Compute base offset. - Value *baseOffset = - extractvalue(int64Ty, baseDesc, helper.pos(kOffsetPosInView)); + Value *baseOffset = extractvalue( + int64Ty, baseDesc, + helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); for (int i = 0, e = memRefType.getRank(); i < e; ++i) { Value *indexing = adaptor.indexings()[i]; Value *min = indexing; @@ -238,12 +232,17 @@ public: baseOffset = add(baseOffset, mul(min, strides[i])); } - // Insert base pointer. - auto ptrPos = helper.pos(kPtrPosInView); + // Insert the base and aligned pointers. + auto ptrPos = + helper.pos(LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor); + desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos); + ptrPos = helper.pos(LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos); // Insert base offset. - desc = insertvalue(desc, baseOffset, helper.pos(kOffsetPosInView)); + desc = insertvalue( + desc, baseOffset, + helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); // Corner case, no sizes or strides: early return the descriptor. if (helper.zeroDMemRef) @@ -262,8 +261,9 @@ public: Value *min = extractvalue(int64Ty, rangeDescriptor, helper.pos(0)); Value *max = extractvalue(int64Ty, rangeDescriptor, helper.pos(1)); Value *step = extractvalue(int64Ty, rangeDescriptor, helper.pos(2)); - Value *baseSize = - extractvalue(int64Ty, baseDesc, helper.pos({kSizePosInView, rank})); + Value *baseSize = extractvalue( + int64Ty, baseDesc, + helper.pos({LLVMTypeConverter::kSizePosInMemRefDescriptor, rank})); // Bound upper by base view upper bound. max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max, baseSize); @@ -272,10 +272,14 @@ public: size = llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size); Value *stride = mul(strides[rank], step); - desc = - insertvalue(desc, size, helper.pos({kSizePosInView, numNewDims})); - desc = insertvalue(desc, stride, - helper.pos({kStridePosInView, numNewDims})); + desc = insertvalue( + desc, size, + helper.pos( + {LLVMTypeConverter::kSizePosInMemRefDescriptor, numNewDims})); + desc = insertvalue( + desc, stride, + helper.pos( + {LLVMTypeConverter::kStridePosInMemRefDescriptor, numNewDims})); ++numNewDims; } } @@ -316,25 +320,39 @@ public: Value *desc = helper.desc; edsc::ScopedContext context(rewriter, op->getLoc()); - // Copy the base pointer from the old descriptor to the new one. - ArrayAttr ptrPos = helper.pos(kPtrPosInView); + // Copy the base and aligned pointers from the old descriptor to the new + // one. + ArrayAttr ptrPos = + helper.pos(LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor); + desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos); + ptrPos = helper.pos(LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos); // Copy the offset pointer from the old descriptor to the new one. - ArrayAttr offPos = helper.pos(kOffsetPosInView); + ArrayAttr offPos = + helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor); desc = insertvalue(desc, extractvalue(int64Ty, baseDesc, offPos), offPos); // Iterate over the dimensions and apply size/stride permutation. for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) { int sourcePos = en.index(); int targetPos = en.value().cast<AffineDimExpr>().getPosition(); - Value *size = extractvalue(int64Ty, baseDesc, - helper.pos({kSizePosInView, sourcePos})); - desc = insertvalue(desc, size, helper.pos({kSizePosInView, targetPos})); - Value *stride = extractvalue(int64Ty, baseDesc, - helper.pos({kStridePosInView, sourcePos})); + Value *size = extractvalue( + int64Ty, baseDesc, + helper.pos( + {LLVMTypeConverter::kSizePosInMemRefDescriptor, sourcePos})); desc = - insertvalue(desc, stride, helper.pos({kStridePosInView, targetPos})); + insertvalue(desc, size, + helper.pos({LLVMTypeConverter::kSizePosInMemRefDescriptor, + targetPos})); + Value *stride = extractvalue( + int64Ty, baseDesc, + helper.pos( + {LLVMTypeConverter::kStridePosInMemRefDescriptor, sourcePos})); + desc = insertvalue( + desc, stride, + helper.pos( + {LLVMTypeConverter::kStridePosInMemRefDescriptor, targetPos})); } rewriter.replaceOp(op, desc); diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 8c08868bc7a..f152ccfc50a 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -438,8 +438,8 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> { newMemRefType.getNumDynamicDims()); // Create and insert the alloc op for the new memref. - auto newAlloc = - rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType, newOperands); + auto newAlloc = rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType, + newOperands, IntegerAttr()); // Insert a cast so we have the same type as the old alloc. auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc, alloc.getType()); |