diff options
author | Alex Zinenko <zinenko@google.com> | 2019-11-14 09:05:11 -0800 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-14 09:05:42 -0800 |
commit | bf5916e7a49b4f8279ee38c0fd20154a101ff026 (patch) | |
tree | 7c6e2fb2ff93ef2497e9ff90bafa0a2c0c785ab5 /mlir/lib/Conversion/VectorToLLVM | |
parent | 62d5b1de45298d0ea3a1c7135555ca83cfa57353 (diff) | |
download | bcm5719-llvm-bf5916e7a49b4f8279ee38c0fd20154a101ff026.tar.gz bcm5719-llvm-bf5916e7a49b4f8279ee38c0fd20154a101ff026.zip |
Use MemRefDescriptor in Vector-to-LLVM convresion
Following up on the consolidation of MemRef descriptor conversion, update
Vector-to-LLVM conversion to use the helper class that abstracts away the
implementation details of the MemRef descriptor. This also makes the types of
the attributes in emitted llvm.insert/extractelement operations consistently
i64 instead of a mix of index and i64.
PiperOrigin-RevId: 280441451
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp | 52 |
1 files changed, 14 insertions, 38 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp index 21bcdc9a6db..5bda8b3fd5b 100644 --- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp @@ -177,22 +177,17 @@ public: !targetMemRefType.hasStaticShape()) return matchFailure(); - Value *sourceMemRef = operands[0]; auto llvmSourceDescriptorTy = - sourceMemRef->getType().dyn_cast<LLVM::LLVMType>(); + operands[0]->getType().dyn_cast<LLVM::LLVMType>(); if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) return matchFailure(); + MemRefDescriptor sourceMemRef(operands[0]); auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType) .dyn_cast_or_null<LLVM::LLVMType>(); if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) return matchFailure(); - Type llvmSourceElementTy = llvmSourceDescriptorTy.getStructElementType( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); - Type llvmTargetElementTy = llvmTargetDescriptorTy.getStructElementType( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); - int64_t offset; SmallVector<int64_t, 4> strides; auto successStrides = @@ -214,55 +209,36 @@ public: auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); // Create descriptor. - Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmTargetDescriptorTy); + auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); + Type llvmTargetElementTy = desc.getElementType(); // Set allocated ptr. - Value *allocated = rewriter.create<LLVM::ExtractValueOp>( - loc, llvmSourceElementTy, sourceMemRef, - rewriter.getIndexArrayAttr( - LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); + Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc); allocated = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); - desc = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), llvmTargetDescriptorTy, desc, allocated, - rewriter.getIndexArrayAttr( - LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); - // Set ptr. - Value *ptr = rewriter.create<LLVM::ExtractValueOp>( - loc, llvmSourceElementTy, sourceMemRef, - rewriter.getIndexArrayAttr( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); + desc.setAllocatedPtr(rewriter, loc, allocated); + // Set aligned ptr. + Value *ptr = sourceMemRef.alignedPtr(rewriter, loc); ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); - desc = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), llvmTargetDescriptorTy, desc, ptr, - rewriter.getIndexArrayAttr( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); + desc.setAlignedPtr(rewriter, loc, ptr); // Fill offset 0. auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); - desc = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), llvmTargetDescriptorTy, desc, zero, - rewriter.getIndexArrayAttr( - LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); + desc.setOffset(rewriter, loc, zero); + // Fill size and stride descriptors in memref. for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { int64_t index = indexedSize.index(); auto sizeAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); - desc = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), llvmTargetDescriptorTy, desc, size, - rewriter.getI64ArrayAttr( - {LLVMTypeConverter::kSizePosInMemRefDescriptor, index})); + desc.setSize(rewriter, loc, index, size); auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); - desc = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), llvmTargetDescriptorTy, desc, stride, - rewriter.getI64ArrayAttr( - {LLVMTypeConverter::kStridePosInMemRefDescriptor, index})); + desc.setStride(rewriter, loc, index, stride); } - rewriter.replaceOp(op, desc); + rewriter.replaceOp(op, {desc}); return matchSuccess(); } }; |