summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/StandardToLLVM
diff options
context:
space:
mode:
authorAlex Zinenko <zinenko@google.com>2019-11-14 00:48:41 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-14 00:49:12 -0800
commitee5c2256ef31fefc92ad59f78b0649b145dc0eb0 (patch)
tree291dc6c6f3248c01da15931b9a887b0497af28c7 /mlir/lib/Conversion/StandardToLLVM
parentd1c99e10d05508855b51ec391c8f1c4a7f4aa14b (diff)
downloadbcm5719-llvm-ee5c2256ef31fefc92ad59f78b0649b145dc0eb0.tar.gz
bcm5719-llvm-ee5c2256ef31fefc92ad59f78b0649b145dc0eb0.zip
Concentrate memref descriptor manipulation logic in one place
Memref descriptor is becoming increasingly complex. Memrefs are manipulated by multiple standard instructions, each of which has a non-trivial lowering to the LLVM dialect. This leads to verbose code that manipulates the descriptors exposing the internals of insert/extractelement opreations. Implement a wrapper class that contains a memref descriptor and provides semantically named methods that build the primitive IR operations instead. PiperOrigin-RevId: 280371225
Diffstat (limited to 'mlir/lib/Conversion/StandardToLLVM')
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp339
1 files changed, 174 insertions, 165 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 791a237c350..0641a6b9ab0 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -235,6 +235,125 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
: ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}
namespace {
+/// Helper class to produce LLVM dialect operations extracting or inserting
+/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
+/// The Value may be null, in which case none of the operations are valid.
+class MemRefDescriptor {
+public:
+ /// Construct a helper for the given descriptor value.
+ explicit MemRefDescriptor(Value *descriptor) : value(descriptor) {
+ if (value) {
+ structType = value->getType().cast<LLVM::LLVMType>();
+ indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
+ LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
+ }
+ }
+
+ /// Builds IR creating an `undef` value of the descriptor type.
+ static MemRefDescriptor undef(OpBuilder &builder, Location loc,
+ Type descriptorType) {
+ Value *descriptor = builder.create<LLVM::UndefOp>(
+ loc, descriptorType.cast<LLVM::LLVMType>());
+ return MemRefDescriptor(descriptor);
+ }
+
+ /// Builds IR extracting the allocated pointer from the descriptor.
+ Value *allocatedPtr(OpBuilder &builder, Location loc) {
+ return extractPtr(builder, loc,
+ LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
+ }
+
+ /// Builds IR inserting the allocated pointer into the descriptor.
+ void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr) {
+ setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor,
+ ptr);
+ }
+
+ /// Builds IR extracting the aligned pointer from the descriptor.
+ Value *alignedPtr(OpBuilder &builder, Location loc) {
+ return extractPtr(builder, loc,
+ LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
+ }
+
+ /// Builds IR inserting the aligned pointer into the descriptor.
+ void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr) {
+ setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor,
+ ptr);
+ }
+
+ /// Builds IR extracting the offset from the descriptor.
+ Value *offset(OpBuilder &builder, Location loc) {
+ return builder.create<LLVM::ExtractValueOp>(
+ loc, indexType, value,
+ builder.getI64ArrayAttr(
+ LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+ }
+
+ /// Builds IR inserting the offset into the descriptor.
+ void setOffset(OpBuilder &builder, Location loc, Value *offset) {
+ value = builder.create<LLVM::InsertValueOp>(
+ loc, structType, value, offset,
+ builder.getI64ArrayAttr(
+ LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+ }
+
+ /// Builds IR extracting the pos-th size from the descriptor.
+ Value *size(OpBuilder &builder, Location loc, unsigned pos) {
+ return builder.create<LLVM::ExtractValueOp>(
+ loc, indexType, value,
+ builder.getI64ArrayAttr(
+ {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
+ }
+
+ /// Builds IR inserting the pos-th size into the descriptor
+ void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size) {
+ value = builder.create<LLVM::InsertValueOp>(
+ loc, structType, value, size,
+ builder.getI64ArrayAttr(
+ {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
+ }
+
+ /// Builds IR extracting the pos-th size from the descriptor.
+ Value *stride(OpBuilder &builder, Location loc, unsigned pos) {
+ return builder.create<LLVM::ExtractValueOp>(
+ loc, indexType, value,
+ builder.getI64ArrayAttr(
+ {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
+ }
+
+ /// Builds IR inserting the pos-th stride into the descriptor
+ void setStride(OpBuilder &builder, Location loc, unsigned pos,
+ Value *stride) {
+ value = builder.create<LLVM::InsertValueOp>(
+ loc, structType, value, stride,
+ builder.getI64ArrayAttr(
+ {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
+ }
+
+ /*implicit*/ operator Value *() { return value; }
+
+private:
+ Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos) {
+ Type type = structType.getStructElementType(pos);
+ return builder.create<LLVM::ExtractValueOp>(loc, type, value,
+ builder.getI64ArrayAttr(pos));
+ }
+
+ void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr) {
+ value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
+ builder.getI64ArrayAttr(pos));
+ }
+
+ // Cached descriptor type.
+ LLVM::LLVMType structType;
+
+ // Cached index type.
+ LLVM::LLVMType indexType;
+
+ // Actual descriptor.
+ Value *value;
+};
+
// Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in
// case it is necessary for rewriters.
@@ -278,29 +397,6 @@ public:
return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
}
- // Extract allocated data pointer value from a value representing a memref.
- static Value *
- extractAllocatedMemRefElementPtr(ConversionPatternRewriter &builder,
- Location loc, Value *memref,
- Type elementTypePtr) {
- return builder.create<LLVM::ExtractValueOp>(
- loc, elementTypePtr, memref,
- builder.getI64ArrayAttr(
- LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
- }
-
- // Extract properly aligned data pointer value from a value representing a
- // memref.
- static Value *
- extractAlignedMemRefElementPtr(ConversionPatternRewriter &builder,
- Location loc, Value *memref,
- Type elementTypePtr) {
- return builder.create<LLVM::ExtractValueOp>(
- loc, elementTypePtr, memref,
- builder.getI64ArrayAttr(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
- }
-
protected:
LLVM::LLVMDialect &dialect;
};
@@ -786,14 +882,10 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Create the MemRef descriptor.
auto structType = lowering.convertType(type);
- Value *memRefDescriptor =
- rewriter.create<LLVM::UndefOp>(loc, structType, ArrayRef<Value *>{});
-
+ auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
// Field 1: Allocated pointer, used for malloc/free.
- memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
- loc, structType, memRefDescriptor, bitcastAllocated,
- rewriter.getI64ArrayAttr(
- LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+ memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated);
+
// Field 2: Actual aligned pointer to payload.
Value *bitcastAligned = bitcastAllocated;
if (align) {
@@ -808,20 +900,15 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
bitcastAligned = rewriter.create<LLVM::BitcastOp>(
loc, elementPtrType, ArrayRef<Value *>(aligned));
}
- memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
- loc, structType, memRefDescriptor, bitcastAligned,
- rewriter.getI64ArrayAttr(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+ memRefDescriptor.setAlignedPtr(rewriter, loc, bitcastAligned);
+
// Field 3: Offset in aligned pointer.
- memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
- loc, structType, memRefDescriptor,
- createIndexConstant(rewriter, loc, offset),
- rewriter.getI64ArrayAttr(
- LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+ memRefDescriptor.setOffset(rewriter, loc,
+ createIndexConstant(rewriter, loc, offset));
if (type.getRank() == 0)
// No size/stride descriptor in memref, return the descriptor value.
- return rewriter.replaceOp(op, memRefDescriptor);
+ return rewriter.replaceOp(op, {memRefDescriptor});
// Fields 4 and 5: Sizes and strides of the strided MemRef.
// Store all sizes in the descriptor. Only dynamic sizes are passed in as
@@ -846,18 +933,12 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Fill size and stride descriptors in memref.
for (auto indexedSize : llvm::enumerate(sizes)) {
int64_t index = indexedSize.index();
- memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
- loc, structType, memRefDescriptor, indexedSize.value(),
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
- memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
- loc, structType, memRefDescriptor, strideValues[index],
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
+ memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value());
+ memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]);
}
// Return the final value of the descriptor.
- rewriter.replaceOp(op, memRefDescriptor);
+ rewriter.replaceOp(op, {memRefDescriptor});
}
};
@@ -947,13 +1028,10 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
/*isVarArg=*/false));
}
- auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
- Type elementPtrType = type.getStructElementType(
- LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
- Value *bufferPtr = extractAllocatedMemRefElementPtr(
- rewriter, op->getLoc(), transformed.memref(), elementPtrType);
+ MemRefDescriptor memref(transformed.memref());
Value *casted = rewriter.create<LLVM::BitcastOp>(
- op->getLoc(), getVoidPtrType(), bufferPtr);
+ op->getLoc(), getVoidPtrType(),
+ memref.allocatedPtr(rewriter, op->getLoc()));
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
return matchSuccess();
@@ -1003,10 +1081,8 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
int64_t index = dimOp.getIndex();
// Extract dynamic size from the memref descriptor.
if (ShapedType::isDynamic(shape[index]))
- rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
- op, getIndexType(), transformed.memrefOrTensor(),
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
+ rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor())
+ .size(rewriter, op->getLoc(), index)});
else
// Use constant for static size.
rewriter.replaceOp(
@@ -1058,34 +1134,21 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
// This is a strided getElementPtr variant that linearizes subscripts as:
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
Value *getStridedElementPtr(Location loc, Type elementTypePtr,
- Value *memRefDescriptor,
- ArrayRef<Value *> indices,
+ Value *descriptor, ArrayRef<Value *> indices,
ArrayRef<int64_t> strides, int64_t offset,
ConversionPatternRewriter &rewriter) const {
- auto indexTy = this->getIndexType();
- Value *base = this->extractAlignedMemRefElementPtr(
- rewriter, loc, memRefDescriptor, elementTypePtr);
- Value *offsetValue =
- offset == MemRefType::getDynamicStrideOrOffset()
- ? rewriter.create<LLVM::ExtractValueOp>(
- loc, indexTy, memRefDescriptor,
- rewriter.getI64ArrayAttr(
- LLVMTypeConverter::kOffsetPosInMemRefDescriptor))
- : this->createIndexConstant(rewriter, loc, offset);
+ MemRefDescriptor memRefDescriptor(descriptor);
+
+ Value *base = memRefDescriptor.alignedPtr(rewriter, loc);
+ Value *offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
+ ? memRefDescriptor.offset(rewriter, loc)
+ : this->createIndexConstant(rewriter, loc, offset);
+
for (int i = 0, e = indices.size(); i < e; ++i) {
- Value *stride;
- if (strides[i] != MemRefType::getDynamicStrideOrOffset()) {
- // Use static stride.
- auto attr =
- rewriter.getIntegerAttr(rewriter.getIndexType(), strides[i]);
- stride = rewriter.create<LLVM::ConstantOp>(loc, indexTy, attr);
- } else {
- // Use dynamic stride.
- stride = rewriter.create<LLVM::ExtractValueOp>(
- loc, indexTy, memRefDescriptor,
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
- }
+ Value *stride =
+ strides[i] == MemRefType::getDynamicStrideOrOffset()
+ ? memRefDescriptor.stride(rewriter, loc, i)
+ : this->createIndexConstant(rewriter, loc, strides[i]);
Value *additionalOffset =
rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
offsetValue =
@@ -1452,74 +1515,45 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
return matchFailure();
// Create the descriptor.
- Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy);
+ MemRefDescriptor sourceMemRef(adaptor.source());
+ auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
// Copy the buffer pointer from the old descriptor to the new one.
- Value *sourceDescriptor = adaptor.source();
- Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, sourceElementTy.getPointerTo(), sourceDescriptor,
- rewriter.getI64ArrayAttr(
- LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+ Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc);
Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc, targetElementTy.getPointerTo(), extracted);
- desc = rewriter.create<LLVM::InsertValueOp>(
- loc, desc, bitcastPtr,
- rewriter.getI64ArrayAttr(
- LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
- extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, sourceElementTy.getPointerTo(), sourceDescriptor,
- rewriter.getI64ArrayAttr(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+ targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
+
+ extracted = sourceMemRef.alignedPtr(rewriter, loc);
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc, targetElementTy.getPointerTo(), extracted);
- desc = rewriter.create<LLVM::InsertValueOp>(
- loc, desc, bitcastPtr,
- rewriter.getI64ArrayAttr(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+ targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
// Extract strides needed to compute offset.
SmallVector<Value *, 4> strideValues;
strideValues.reserve(viewMemRefType.getRank());
- for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
- strideValues.push_back(rewriter.create<LLVM::ExtractValueOp>(
- loc, getIndexType(), sourceDescriptor,
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kStridePosInMemRefDescriptor, i})));
- }
+ for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i)
+ strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
// Offset.
- Value *baseOffset = rewriter.create<LLVM::ExtractValueOp>(
- loc, getIndexType(), sourceDescriptor,
- rewriter.getI64ArrayAttr(
- LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+ Value *baseOffset = sourceMemRef.offset(rewriter, loc);
for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
Value *min = adaptor.offsets()[i];
baseOffset = rewriter.create<LLVM::AddOp>(
loc, baseOffset,
rewriter.create<LLVM::MulOp>(loc, min, strideValues[i]));
}
- desc = rewriter.create<LLVM::InsertValueOp>(
- loc, desc, baseOffset,
- rewriter.getI64ArrayAttr(
- LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+ targetMemRef.setOffset(rewriter, loc, baseOffset);
// Update sizes and strides.
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
- // Update size.
- desc = rewriter.create<LLVM::InsertValueOp>(
- loc, desc, adaptor.sizes()[i],
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kSizePosInMemRefDescriptor, i}));
- // Update stride.
- desc = rewriter.create<LLVM::InsertValueOp>(
- loc, desc,
- rewriter.create<LLVM::MulOp>(loc, adaptor.strides()[i],
- strideValues[i]),
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
+ targetMemRef.setSize(rewriter, loc, i, adaptor.sizes()[i]);
+ targetMemRef.setStride(rewriter, loc, i,
+ rewriter.create<LLVM::MulOp>(
+ loc, adaptor.strides()[i], strideValues[i]));
}
- rewriter.replaceOp(op, desc);
+ rewriter.replaceOp(op, {targetMemRef});
return matchSuccess();
}
};
@@ -1571,10 +1605,6 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
auto loc = op->getLoc();
auto viewOp = cast<ViewOp>(op);
ViewOpOperandAdaptor adaptor(operands);
- auto sourceMemRefType = viewOp.source()->getType().cast<MemRefType>();
- auto sourceElementTy =
- lowering.convertType(sourceMemRefType.getElementType())
- .dyn_cast<LLVM::LLVMType>();
auto viewMemRefType = viewOp.getType();
auto targetElementTy = lowering.convertType(viewMemRefType.getElementType())
@@ -1593,32 +1623,20 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
matchFailure();
// Create the descriptor.
- Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy);
+ MemRefDescriptor sourceMemRef(adaptor.source());
+ auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
// Field 1: Copy the allocated pointer, used for malloc/free.
- Value *sourceDescriptor = adaptor.source();
- Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, sourceElementTy.getPointerTo(), sourceDescriptor,
- rewriter.getI64ArrayAttr(
- LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+ Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc);
Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc, targetElementTy.getPointerTo(), extracted);
- desc = rewriter.create<LLVM::InsertValueOp>(
- loc, desc, bitcastPtr,
- rewriter.getI64ArrayAttr(
- LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+ targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
// Field 2: Copy the actual aligned pointer to payload.
- extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, sourceElementTy.getPointerTo(), sourceDescriptor,
- rewriter.getI64ArrayAttr(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+ extracted = sourceMemRef.alignedPtr(rewriter, loc);
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc, targetElementTy.getPointerTo(), extracted);
- desc = rewriter.create<LLVM::InsertValueOp>(
- loc, desc, bitcastPtr,
- rewriter.getI64ArrayAttr(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+ targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
// Field 3: Copy the offset in aligned pointer.
unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes());
@@ -1630,14 +1648,11 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
? createIndexConstant(rewriter, loc, offset)
// TODO(ntv): better adaptor.
: sizeAndOffsetOperands.back();
- desc = rewriter.create<LLVM::InsertValueOp>(
- loc, desc, baseOffset,
- rewriter.getI64ArrayAttr(
- LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+ targetMemRef.setOffset(rewriter, loc, baseOffset);
// Early exit for 0-D corner case.
if (viewMemRefType.getRank() == 0)
- return rewriter.replaceOp(op, desc), matchSuccess();
+ return rewriter.replaceOp(op, {targetMemRef}), matchSuccess();
// Fields 4 and 5: Update sizes and strides.
if (strides.back() != 1)
@@ -1648,20 +1663,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
// Update size.
Value *size = getSize(rewriter, loc, viewMemRefType.getShape(),
sizeAndOffsetOperands, i);
- desc = rewriter.create<LLVM::InsertValueOp>(
- loc, desc, size,
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kSizePosInMemRefDescriptor, i}));
+ targetMemRef.setSize(rewriter, loc, i, size);
// Update stride.
stride = getStride(rewriter, loc, strides, nextSize, stride, i);
- desc = rewriter.create<LLVM::InsertValueOp>(
- loc, desc, stride,
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
+ targetMemRef.setStride(rewriter, loc, i, stride);
nextSize = size;
}
- rewriter.replaceOp(op, desc);
+ rewriter.replaceOp(op, {targetMemRef});
return matchSuccess();
}
};
OpenPOWER on IntegriCloud