path: root/mlir/lib
author     Nicolas Vasilache <ntv@google.com>    2019-11-12 07:06:18 -0800
committer  A. Unique TensorFlower <gardener@tensorflow.org>    2019-11-12 07:06:54 -0800
commit     f51a15533729cddc9907320b5ab963f7fc037aa0 (patch)
tree       f2c907948d881741510dc8574fe12dc468a72021 /mlir/lib
parent     6582489219ab695a025457302a9e6924b1259176 (diff)
Add support for alignment attribute in std.alloc.
This CL adds an extra pointer to the memref descriptor to allow specifying alignment.

In a previous implementation, we used 2 types: `linalg.buffer` and `view`, where the buffer type was the unit of allocation/deallocation/alignment and `view` was the unit of indexing. After multiple discussions it was decided to use a single type, which conflates both, so the memref descriptor now needs to carry both pointers.

This is consistent with the [RFC-Proposed Changes to MemRef and Tensor MLIR Types](https://groups.google.com/a/tensorflow.org/forum/#!searchin/mlir/std.view%7Csort:date/mlir/-wKHANzDNTg/4K6nUAp8AAAJ).

PiperOrigin-RevId: 279959463
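For reference, the lowered descriptor layout after this change corresponds to the following C++-style sketch, mirroring the comment updated in ConvertStandardToLLVM.cpp below (`Elem` and `Rank` are placeholders; the sizes/strides arrays are omitted for rank-0 memrefs):

    #include <cstdint>

    // Sketch of the memref descriptor after this change (rank > 0).
    // Previously the descriptor held a single `Elem *ptr`; it now carries
    // both the allocated pointer (for free) and the aligned pointer (for indexing).
    template <typename Elem, size_t Rank>
    struct MemRefDescriptor {
      Elem *allocatedPtr;     // pointer returned by malloc, passed to free
      Elem *alignedPtr;       // allocatedPtr rounded up to the requested alignment
      int64_t offset;         // offset into the aligned pointer
      int64_t sizes[Rank];    // omitted when Rank == 0
      int64_t strides[Rank];  // omitted when Rank == 0
    };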
Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp  186
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp  19
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp  78
-rw-r--r--  mlir/lib/Dialect/StandardOps/Ops.cpp  4
4 files changed, 188 insertions, 99 deletions
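Conceptually, the alignment handling added to AllocOpLowering (first file below) over-allocates by `alignment - 1` bytes and then bumps the pointer to the next aligned address. A minimal C-style sketch of that arithmetic, with illustrative names that are not part of the patch (the actual lowering emits LLVM dialect ops instead):

    #include <stdint.h>
    #include <stdlib.h>

    // Returns an aligned pointer for indexing; *allocatedOut keeps the
    // original malloc'ed pointer so it can later be passed to free().
    void *alloc_aligned(size_t size, size_t alignment, void **allocatedOut) {
      // Allocate `alignment - 1` extra bytes so an aligned address always fits.
      void *allocated = malloc(size + alignment - 1);
      *allocatedOut = allocated;
      uintptr_t ptr = (uintptr_t)allocated;
      // offset = (align - (ptr % align)) % align, as in the lowering below.
      uintptr_t offset = (alignment - ptr % alignment) % alignment;
      return (void *)(ptr + offset);
    }

When no alignment attribute is specified, no over-allocation happens and the allocated and aligned pointers in the descriptor are simply equal.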
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 1584bd4d4ed..84eba820b80 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -151,12 +151,14 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
//
// template <typename Elem, size_t Rank>
// struct {
-// Elem *ptr;
+// Elem *allocatedPtr;
+// Elem *alignedPtr;
// int64_t offset;
// int64_t sizes[Rank]; // omitted when rank == 0
// int64_t strides[Rank]; // omitted when rank == 0
// };
-constexpr unsigned LLVMTypeConverter::kPtrPosInMemRefDescriptor;
+constexpr unsigned LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor;
+constexpr unsigned LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor;
constexpr unsigned LLVMTypeConverter::kOffsetPosInMemRefDescriptor;
constexpr unsigned LLVMTypeConverter::kSizePosInMemRefDescriptor;
constexpr unsigned LLVMTypeConverter::kStridePosInMemRefDescriptor;
@@ -175,9 +177,9 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
auto rank = type.getRank();
if (rank > 0) {
auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, type.getRank());
- return LLVM::LLVMType::getStructTy(ptrTy, indexTy, arrayTy, arrayTy);
+ return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy, arrayTy, arrayTy);
}
- return LLVM::LLVMType::getStructTy(ptrTy, indexTy);
+ return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy);
}
// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
@@ -276,14 +278,27 @@ public:
return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
}
- // Extract raw data pointer value from a value representing a memref.
- static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder,
- Location loc, Value *memref,
- Type elementTypePtr) {
+ // Extract allocated data pointer value from a value representing a memref.
+ static Value *
+ extractAllocatedMemRefElementPtr(ConversionPatternRewriter &builder,
+ Location loc, Value *memref,
+ Type elementTypePtr) {
return builder.create<LLVM::ExtractValueOp>(
loc, elementTypePtr, memref,
- builder.getIndexArrayAttr(
- LLVMTypeConverter::kPtrPosInMemRefDescriptor));
+ builder.getI64ArrayAttr(
+ LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+ }
+
+ // Extract properly aligned data pointer value from a value representing a
+ // memref.
+ static Value *
+ extractAlignedMemRefElementPtr(ConversionPatternRewriter &builder,
+ Location loc, Value *memref,
+ Type elementTypePtr) {
+ return builder.create<LLVM::ExtractValueOp>(
+ loc, elementTypePtr, memref,
+ builder.getI64ArrayAttr(
+ LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
}
protected:
@@ -442,7 +457,7 @@ void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
if (coords.empty())
break;
assert(coords.size() == info.arraySizes.size());
- auto position = builder.getIndexArrayAttr(coords);
+ auto position = builder.getI64ArrayAttr(coords);
fun(position);
}
}
@@ -488,7 +503,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
auto type = this->lowering.convertType(op->getResult(i)->getType());
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), type, newOp.getOperation()->getResult(0),
- rewriter.getIndexArrayAttr(i)));
+ rewriter.getI64ArrayAttr(i)));
}
rewriter.replaceOp(op, results);
return this->matchSuccess();
@@ -650,9 +665,16 @@ static bool isSupportedMemRefType(MemRefType type) {
// An `alloc` is converted into a definition of a memref descriptor value and
// a call to `malloc` to allocate the underlying data buffer. The memref
-// descriptor is of the LLVM structure type where the first element is a pointer
-// to the (typed) data buffer, and the remaining elements serve to store
-// dynamic sizes of the memref using LLVM-converted `index` type.
+// descriptor is of the LLVM structure type where:
+// 1. the first element is a pointer to the allocated (typed) data buffer,
+// 2. the second element is a pointer to the (typed) payload, aligned to the
+// specified alignment,
+// 3. the remaining elements serve to store all the sizes and strides of the
+// memref using LLVM-converted `index` type.
+//
+// Alignment is obtained by allocating `alignment - 1` more bytes than requested
+// and shifting the aligned pointer relative to the allocated memory. If
+// alignment is unspecified, the two pointers are equal.
struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
@@ -678,6 +700,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
void rewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
auto allocOp = cast<AllocOp>(op);
MemRefType type = allocOp.getType();
@@ -689,16 +712,15 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
unsigned i = 0;
for (int64_t s : type.getShape())
sizes.push_back(s == -1 ? operands[i++]
- : createIndexConstant(rewriter, op->getLoc(), s));
+ : createIndexConstant(rewriter, loc, s));
if (sizes.empty())
- sizes.push_back(createIndexConstant(rewriter, op->getLoc(), 1));
+ sizes.push_back(createIndexConstant(rewriter, loc, 1));
// Compute the total number of memref elements.
Value *cumulativeSize = sizes.front();
for (unsigned i = 1, e = sizes.size(); i < e; ++i)
cumulativeSize = rewriter.create<LLVM::MulOp>(
- op->getLoc(), getIndexType(),
- ArrayRef<Value *>{cumulativeSize, sizes[i]});
+ loc, getIndexType(), ArrayRef<Value *>{cumulativeSize, sizes[i]});
// Compute the size of an individual element. This emits the MLIR equivalent
// of the following sizeof(...) implementation in LLVM IR:
@@ -708,16 +730,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
auto elementType = type.getElementType();
auto convertedPtrType =
lowering.convertType(elementType).cast<LLVM::LLVMType>().getPointerTo();
- auto nullPtr =
- rewriter.create<LLVM::NullOp>(op->getLoc(), convertedPtrType);
- auto one = createIndexConstant(rewriter, op->getLoc(), 1);
- auto gep = rewriter.create<LLVM::GEPOp>(op->getLoc(), convertedPtrType,
+ auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
+ auto one = createIndexConstant(rewriter, loc, 1);
+ auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType,
ArrayRef<Value *>{nullPtr, one});
auto elementSize =
- rewriter.create<LLVM::PtrToIntOp>(op->getLoc(), getIndexType(), gep);
+ rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
cumulativeSize = rewriter.create<LLVM::MulOp>(
- op->getLoc(), getIndexType(),
- ArrayRef<Value *>{cumulativeSize, elementSize});
+ loc, getIndexType(), ArrayRef<Value *>{cumulativeSize, elementSize});
// Insert the `malloc` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>();
@@ -732,17 +752,24 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
+ Value *align = nullptr;
+ if (auto alignAttr = allocOp.alignment()) {
+ align = createIndexConstant(rewriter, loc,
+ alignAttr.getValue().getSExtValue());
+ cumulativeSize = rewriter.create<LLVM::SubOp>(
+ loc, rewriter.create<LLVM::AddOp>(loc, cumulativeSize, align), one);
+ }
Value *allocated =
rewriter
- .create<LLVM::CallOp>(op->getLoc(), getVoidPtrType(),
+ .create<LLVM::CallOp>(loc, getVoidPtrType(),
rewriter.getSymbolRefAttr(mallocFunc),
cumulativeSize)
.getResult(0);
auto structElementType = lowering.convertType(elementType);
auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo(
type.getMemorySpace());
- allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType,
- ArrayRef<Value *>(allocated));
+ Value *bitcastAllocated = rewriter.create<LLVM::BitcastOp>(
+ loc, elementPtrType, ArrayRef<Value *>(allocated));
int64_t offset;
SmallVector<int64_t, 4> strides;
@@ -759,23 +786,44 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Create the MemRef descriptor.
auto structType = lowering.convertType(type);
- Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>(
- op->getLoc(), structType, ArrayRef<Value *>{});
+ Value *memRefDescriptor =
+ rewriter.create<LLVM::UndefOp>(loc, structType, ArrayRef<Value *>{});
+ // Field 1: Allocated pointer, used for malloc/free.
+ memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
+ loc, structType, memRefDescriptor, bitcastAllocated,
+ rewriter.getI64ArrayAttr(
+ LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+ // Field 2: Actual aligned pointer to payload.
+ Value *bitcastAligned = bitcastAllocated;
+ if (align) {
+ // offset = (align - (ptr % align))% align
+ Value *intVal = rewriter.create<LLVM::PtrToIntOp>(
+ loc, this->getIndexType(), allocated);
+ Value *ptrModAlign = rewriter.create<LLVM::URemOp>(loc, intVal, align);
+ Value *subbed = rewriter.create<LLVM::SubOp>(loc, align, ptrModAlign);
+ Value *offset = rewriter.create<LLVM::URemOp>(loc, subbed, align);
+ Value *aligned = rewriter.create<LLVM::GEPOp>(loc, allocated->getType(),
+ allocated, offset);
+ bitcastAligned = rewriter.create<LLVM::BitcastOp>(
+ loc, elementPtrType, ArrayRef<Value *>(aligned));
+ }
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), structType, memRefDescriptor, allocated,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kPtrPosInMemRefDescriptor));
+ loc, structType, memRefDescriptor, bitcastAligned,
+ rewriter.getI64ArrayAttr(
+ LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+ // Field 3: Offset in aligned pointer.
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), structType, memRefDescriptor,
- createIndexConstant(rewriter, op->getLoc(), offset),
- rewriter.getIndexArrayAttr(
+ loc, structType, memRefDescriptor,
+ createIndexConstant(rewriter, loc, offset),
+ rewriter.getI64ArrayAttr(
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
if (type.getRank() == 0)
// No size/stride descriptor in memref, return the descriptor value.
return rewriter.replaceOp(op, memRefDescriptor);
+ // Fields 4 and 5: Sizes and strides of the strided MemRef.
// Store all sizes in the descriptor. Only dynamic sizes are passed in as
// operands to AllocOp.
Value *runningStride = nullptr;
@@ -787,24 +835,23 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
if (strides[index] == MemRefType::getDynamicStrideOrOffset())
// Identity layout map is enforced in the match function, so we compute:
// `runningStride *= sizes[index]`
- runningStride = runningStride
- ? rewriter.create<LLVM::MulOp>(
- op->getLoc(), runningStride, sizes[index])
- : createIndexConstant(rewriter, op->getLoc(), 1);
- else
runningStride =
- createIndexConstant(rewriter, op->getLoc(), strides[index]);
+ runningStride
+ ? rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[index])
+ : createIndexConstant(rewriter, loc, 1);
+ else
+ runningStride = createIndexConstant(rewriter, loc, strides[index]);
strideValues[index] = runningStride;
}
// Fill size and stride descriptors in memref.
for (auto indexedSize : llvm::enumerate(sizes)) {
int64_t index = indexedSize.index();
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
+ loc, structType, memRefDescriptor, indexedSize.value(),
rewriter.getI64ArrayAttr(
{LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), structType, memRefDescriptor, strideValues[index],
+ loc, structType, memRefDescriptor, strideValues[index],
rewriter.getI64ArrayAttr(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
}
@@ -861,7 +908,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
auto type = this->lowering.convertType(op->getResult(i)->getType());
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), type, newOp.getOperation()->getResult(0),
- rewriter.getIndexArrayAttr(i)));
+ rewriter.getI64ArrayAttr(i)));
}
rewriter.replaceOp(op, results);
@@ -901,9 +948,9 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
}
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
- Type elementPtrType =
- type.getStructElementType(LLVMTypeConverter::kPtrPosInMemRefDescriptor);
- Value *bufferPtr = extractMemRefElementPtr(
+ Type elementPtrType = type.getStructElementType(
+ LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
+ Value *bufferPtr = extractAllocatedMemRefElementPtr(
rewriter, op->getLoc(), transformed.memref(), elementPtrType);
Value *casted = rewriter.create<LLVM::BitcastOp>(
op->getLoc(), getVoidPtrType(), bufferPtr);
@@ -1016,13 +1063,13 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
ArrayRef<int64_t> strides, int64_t offset,
ConversionPatternRewriter &rewriter) const {
auto indexTy = this->getIndexType();
- Value *base = this->extractMemRefElementPtr(rewriter, loc, memRefDescriptor,
- elementTypePtr);
+ Value *base = this->extractAlignedMemRefElementPtr(
+ rewriter, loc, memRefDescriptor, elementTypePtr);
Value *offsetValue =
offset == MemRefType::getDynamicStrideOrOffset()
? rewriter.create<LLVM::ExtractValueOp>(
loc, indexTy, memRefDescriptor,
- rewriter.getIndexArrayAttr(
+ rewriter.getI64ArrayAttr(
LLVMTypeConverter::kOffsetPosInMemRefDescriptor))
: this->createIndexConstant(rewriter, loc, offset);
for (int i = 0, e = indices.size(); i < e; ++i) {
@@ -1036,7 +1083,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
// Use dynamic stride.
stride = rewriter.create<LLVM::ExtractValueOp>(
loc, indexTy, memRefDescriptor,
- rewriter.getIndexArrayAttr(
+ rewriter.getI64ArrayAttr(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
}
Value *additionalOffset =
@@ -1261,7 +1308,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
for (unsigned i = 0; i < numArguments; ++i) {
packed = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), packedType, packed, operands[i],
- rewriter.getIndexArrayAttr(i));
+ rewriter.getI64ArrayAttr(i));
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, llvm::makeArrayRef(packed),
llvm::ArrayRef<Block *>(),
@@ -1436,19 +1483,32 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
// Create the descriptor.
Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy);
- // Copy the buffer pointer from the old descriptor to the new one.
+ // Field 1: Copy the allocated pointer, used for malloc/free.
Value *sourceDescriptor = adaptor.source();
+ Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, sourceElementTy.getPointerTo(), sourceDescriptor,
+ rewriter.getI64ArrayAttr(
+ LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
- loc, targetElementTy.getPointerTo(),
- rewriter.create<LLVM::ExtractValueOp>(
- loc, sourceElementTy.getPointerTo(), sourceDescriptor,
- rewriter.getI64ArrayAttr(
- LLVMTypeConverter::kPtrPosInMemRefDescriptor)));
+ loc, targetElementTy.getPointerTo(), extracted);
desc = rewriter.create<LLVM::InsertValueOp>(
loc, desc, bitcastPtr,
- rewriter.getI64ArrayAttr(LLVMTypeConverter::kPtrPosInMemRefDescriptor));
+ rewriter.getI64ArrayAttr(
+ LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+
+ // Field 2: Copy the actual aligned pointer to payload.
+ extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, sourceElementTy.getPointerTo(), sourceDescriptor,
+ rewriter.getI64ArrayAttr(
+ LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+ bitcastPtr = rewriter.create<LLVM::BitcastOp>(
+ loc, targetElementTy.getPointerTo(), extracted);
+ desc = rewriter.create<LLVM::InsertValueOp>(
+ loc, desc, bitcastPtr,
+ rewriter.getI64ArrayAttr(
+ LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
- // Offset.
+ // Field 3: Copy the offset in aligned pointer.
unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes());
(void)numDynamicSizes;
auto sizeAndOffsetOperands = adaptor.operands();
@@ -1467,7 +1527,7 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
if (viewMemRefType.getRank() == 0)
return rewriter.replaceOp(op, desc), matchSuccess();
- // Update sizes and strides.
+ // Fields 4 and 5: Update sizes and strides.
if (strides.back() != 1)
return op->emitWarning("cannot cast to non-contiguous shape"),
matchFailure();
diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
index 765c25ae227..5ccf740f2fb 100644
--- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
@@ -189,9 +189,9 @@ public:
return matchFailure();
Type llvmSourceElementTy = llvmSourceDescriptorTy.getStructElementType(
- LLVMTypeConverter::kPtrPosInMemRefDescriptor);
+ LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
Type llvmTargetElementTy = llvmTargetDescriptorTy.getStructElementType(
- LLVMTypeConverter::kPtrPosInMemRefDescriptor);
+ LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
int64_t offset;
SmallVector<int64_t, 4> strides;
@@ -215,16 +215,27 @@ public:
// Create descriptor.
Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmTargetDescriptorTy);
+ // Set allocated ptr.
+ Value *allocated = rewriter.create<LLVM::ExtractValueOp>(
+ loc, llvmSourceElementTy, sourceMemRef,
+ rewriter.getIndexArrayAttr(
+ LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+ allocated =
+ rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
+ desc = rewriter.create<LLVM::InsertValueOp>(
+ op->getLoc(), llvmTargetDescriptorTy, desc, allocated,
+ rewriter.getIndexArrayAttr(
+ LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
// Set ptr.
Value *ptr = rewriter.create<LLVM::ExtractValueOp>(
loc, llvmSourceElementTy, sourceMemRef,
rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kPtrPosInMemRefDescriptor));
+ LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
desc = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), llvmTargetDescriptorTy, desc, ptr,
rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kPtrPosInMemRefDescriptor));
+ LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
// Fill offset 0.
auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
index 2dc46bf7b2b..05f4bced063 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -127,14 +127,6 @@ static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
return Type();
}
-static constexpr int kBasePtrPosInBuffer = 0;
-static constexpr int kPtrPosInBuffer = 1;
-static constexpr int kSizePosInBuffer = 2;
-static constexpr int kPtrPosInView = 0;
-static constexpr int kOffsetPosInView = 1;
-static constexpr int kSizePosInView = 2;
-static constexpr int kStridePosInView = 3;
-
namespace {
/// Factor out the common information for all view conversions:
/// 1. common types in (standard and LLVM dialects)
@@ -224,12 +216,14 @@ public:
// TODO(ntv): extract sizes and emit asserts.
SmallVector<Value *, 4> strides(memRefType.getRank());
for (int i = 0, e = memRefType.getRank(); i < e; ++i)
- strides[i] =
- extractvalue(int64Ty, baseDesc, helper.pos({kStridePosInView, i}));
+ strides[i] = extractvalue(
+ int64Ty, baseDesc,
+ helper.pos({LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
// Compute base offset.
- Value *baseOffset =
- extractvalue(int64Ty, baseDesc, helper.pos(kOffsetPosInView));
+ Value *baseOffset = extractvalue(
+ int64Ty, baseDesc,
+ helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
Value *indexing = adaptor.indexings()[i];
Value *min = indexing;
@@ -238,12 +232,17 @@ public:
baseOffset = add(baseOffset, mul(min, strides[i]));
}
- // Insert base pointer.
- auto ptrPos = helper.pos(kPtrPosInView);
+ // Insert the base and aligned pointers.
+ auto ptrPos =
+ helper.pos(LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
+ desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
+ ptrPos = helper.pos(LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
// Insert base offset.
- desc = insertvalue(desc, baseOffset, helper.pos(kOffsetPosInView));
+ desc = insertvalue(
+ desc, baseOffset,
+ helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
// Corner case, no sizes or strides: early return the descriptor.
if (helper.zeroDMemRef)
@@ -262,8 +261,9 @@ public:
Value *min = extractvalue(int64Ty, rangeDescriptor, helper.pos(0));
Value *max = extractvalue(int64Ty, rangeDescriptor, helper.pos(1));
Value *step = extractvalue(int64Ty, rangeDescriptor, helper.pos(2));
- Value *baseSize =
- extractvalue(int64Ty, baseDesc, helper.pos({kSizePosInView, rank}));
+ Value *baseSize = extractvalue(
+ int64Ty, baseDesc,
+ helper.pos({LLVMTypeConverter::kSizePosInMemRefDescriptor, rank}));
// Bound upper by base view upper bound.
max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
baseSize);
@@ -272,10 +272,14 @@ public:
size =
llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
Value *stride = mul(strides[rank], step);
- desc =
- insertvalue(desc, size, helper.pos({kSizePosInView, numNewDims}));
- desc = insertvalue(desc, stride,
- helper.pos({kStridePosInView, numNewDims}));
+ desc = insertvalue(
+ desc, size,
+ helper.pos(
+ {LLVMTypeConverter::kSizePosInMemRefDescriptor, numNewDims}));
+ desc = insertvalue(
+ desc, stride,
+ helper.pos(
+ {LLVMTypeConverter::kStridePosInMemRefDescriptor, numNewDims}));
++numNewDims;
}
}
@@ -316,25 +320,39 @@ public:
Value *desc = helper.desc;
edsc::ScopedContext context(rewriter, op->getLoc());
- // Copy the base pointer from the old descriptor to the new one.
- ArrayAttr ptrPos = helper.pos(kPtrPosInView);
+ // Copy the base and aligned pointers from the old descriptor to the new
+ // one.
+ ArrayAttr ptrPos =
+ helper.pos(LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
+ desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
+ ptrPos = helper.pos(LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
// Copy the offset pointer from the old descriptor to the new one.
- ArrayAttr offPos = helper.pos(kOffsetPosInView);
+ ArrayAttr offPos =
+ helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
desc = insertvalue(desc, extractvalue(int64Ty, baseDesc, offPos), offPos);
// Iterate over the dimensions and apply size/stride permutation.
for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
int sourcePos = en.index();
int targetPos = en.value().cast<AffineDimExpr>().getPosition();
- Value *size = extractvalue(int64Ty, baseDesc,
- helper.pos({kSizePosInView, sourcePos}));
- desc = insertvalue(desc, size, helper.pos({kSizePosInView, targetPos}));
- Value *stride = extractvalue(int64Ty, baseDesc,
- helper.pos({kStridePosInView, sourcePos}));
+ Value *size = extractvalue(
+ int64Ty, baseDesc,
+ helper.pos(
+ {LLVMTypeConverter::kSizePosInMemRefDescriptor, sourcePos}));
desc =
- insertvalue(desc, stride, helper.pos({kStridePosInView, targetPos}));
+ insertvalue(desc, size,
+ helper.pos({LLVMTypeConverter::kSizePosInMemRefDescriptor,
+ targetPos}));
+ Value *stride = extractvalue(
+ int64Ty, baseDesc,
+ helper.pos(
+ {LLVMTypeConverter::kStridePosInMemRefDescriptor, sourcePos}));
+ desc = insertvalue(
+ desc, stride,
+ helper.pos(
+ {LLVMTypeConverter::kStridePosInMemRefDescriptor, targetPos}));
}
rewriter.replaceOp(op, desc);
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 8c08868bc7a..f152ccfc50a 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -438,8 +438,8 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
newMemRefType.getNumDynamicDims());
// Create and insert the alloc op for the new memref.
- auto newAlloc =
- rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType, newOperands);
+ auto newAlloc = rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType,
+ newOperands, IntegerAttr());
// Insert a cast so we have the same type as the old alloc.
auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
alloc.getType());