summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/VectorToLLVM
diff options
context:
space:
mode:
authorAlex Zinenko <zinenko@google.com>2019-11-14 09:05:11 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-14 09:05:42 -0800
commitbf5916e7a49b4f8279ee38c0fd20154a101ff026 (patch)
tree7c6e2fb2ff93ef2497e9ff90bafa0a2c0c785ab5 /mlir/lib/Conversion/VectorToLLVM
parent62d5b1de45298d0ea3a1c7135555ca83cfa57353 (diff)
downloadbcm5719-llvm-bf5916e7a49b4f8279ee38c0fd20154a101ff026.tar.gz
bcm5719-llvm-bf5916e7a49b4f8279ee38c0fd20154a101ff026.zip
Use MemRefDescriptor in Vector-to-LLVM convresion
Following up on the consolidation of MemRef descriptor conversion, update Vector-to-LLVM conversion to use the helper class that abstracts away the implementation details of the MemRef descriptor. This also makes the types of the attributes in emitted llvm.insert/extractelement operations consistently i64 instead of a mix of index and i64. PiperOrigin-RevId: 280441451
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp52
1 files changed, 14 insertions, 38 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
index 21bcdc9a6db..5bda8b3fd5b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
@@ -177,22 +177,17 @@ public:
!targetMemRefType.hasStaticShape())
return matchFailure();
- Value *sourceMemRef = operands[0];
auto llvmSourceDescriptorTy =
- sourceMemRef->getType().dyn_cast<LLVM::LLVMType>();
+ operands[0]->getType().dyn_cast<LLVM::LLVMType>();
if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
return matchFailure();
+ MemRefDescriptor sourceMemRef(operands[0]);
auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return matchFailure();
- Type llvmSourceElementTy = llvmSourceDescriptorTy.getStructElementType(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
- Type llvmTargetElementTy = llvmTargetDescriptorTy.getStructElementType(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
-
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides =
@@ -214,55 +209,36 @@ public:
auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
// Create descriptor.
- Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmTargetDescriptorTy);
+ auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
+ Type llvmTargetElementTy = desc.getElementType();
// Set allocated ptr.
- Value *allocated = rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmSourceElementTy, sourceMemRef,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+ Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc);
allocated =
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, allocated,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
- // Set ptr.
- Value *ptr = rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmSourceElementTy, sourceMemRef,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+ desc.setAllocatedPtr(rewriter, loc, allocated);
+ // Set aligned ptr.
+ Value *ptr = sourceMemRef.alignedPtr(rewriter, loc);
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, ptr,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+ desc.setAlignedPtr(rewriter, loc, ptr);
// Fill offset 0.
auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, zero,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+ desc.setOffset(rewriter, loc, zero);
+
// Fill size and stride descriptors in memref.
for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
int64_t index = indexedSize.index();
auto sizeAttr =
rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, size,
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
+ desc.setSize(rewriter, loc, index, size);
auto strideAttr =
rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, stride,
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
+ desc.setStride(rewriter, loc, index, stride);
}
- rewriter.replaceOp(op, desc);
+ rewriter.replaceOp(op, {desc});
return matchSuccess();
}
};
OpenPOWER on IntegriCloud