diff options
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r-- | mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 5 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp | 52 |
2 files changed, 19 insertions, 38 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 570b6c4bcf2..e0edb0bc047 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -344,6 +344,11 @@ void MemRefDescriptor::setPtr(OpBuilder &builder, Location loc, unsigned pos, builder.getI64ArrayAttr(pos)); } +LLVM::LLVMType MemRefDescriptor::getElementType() { + return value->getType().cast<LLVM::LLVMType>().getStructElementType( + LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); +} + namespace { // Base class for Standard to LLVM IR op conversions. Matches the Op type // provided as template argument. Carries a reference to the LLVM dialect in diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp index 21bcdc9a6db..5bda8b3fd5b 100644 --- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp @@ -177,22 +177,17 @@ public: !targetMemRefType.hasStaticShape()) return matchFailure(); - Value *sourceMemRef = operands[0]; auto llvmSourceDescriptorTy = - sourceMemRef->getType().dyn_cast<LLVM::LLVMType>(); + operands[0]->getType().dyn_cast<LLVM::LLVMType>(); if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) return matchFailure(); + MemRefDescriptor sourceMemRef(operands[0]); auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType) .dyn_cast_or_null<LLVM::LLVMType>(); if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) return matchFailure(); - Type llvmSourceElementTy = llvmSourceDescriptorTy.getStructElementType( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); - Type llvmTargetElementTy = llvmTargetDescriptorTy.getStructElementType( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); - int64_t offset; SmallVector<int64_t, 4> strides; auto successStrides = @@ -214,55 +209,36 @@ public: auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); // Create descriptor. - Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmTargetDescriptorTy); + auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); + Type llvmTargetElementTy = desc.getElementType(); // Set allocated ptr. - Value *allocated = rewriter.create<LLVM::ExtractValueOp>( - loc, llvmSourceElementTy, sourceMemRef, - rewriter.getIndexArrayAttr( - LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); + Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc); allocated = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); - desc = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), llvmTargetDescriptorTy, desc, allocated, - rewriter.getIndexArrayAttr( - LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); - // Set ptr. - Value *ptr = rewriter.create<LLVM::ExtractValueOp>( - loc, llvmSourceElementTy, sourceMemRef, - rewriter.getIndexArrayAttr( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); + desc.setAllocatedPtr(rewriter, loc, allocated); + // Set aligned ptr. + Value *ptr = sourceMemRef.alignedPtr(rewriter, loc); ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); - desc = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), llvmTargetDescriptorTy, desc, ptr, - rewriter.getIndexArrayAttr( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); + desc.setAlignedPtr(rewriter, loc, ptr); // Fill offset 0. auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); - desc = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), llvmTargetDescriptorTy, desc, zero, - rewriter.getIndexArrayAttr( - LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); + desc.setOffset(rewriter, loc, zero); + // Fill size and stride descriptors in memref. for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { int64_t index = indexedSize.index(); auto sizeAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); - desc = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), llvmTargetDescriptorTy, desc, size, - rewriter.getI64ArrayAttr( - {LLVMTypeConverter::kSizePosInMemRefDescriptor, index})); + desc.setSize(rewriter, loc, index, size); auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); - desc = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), llvmTargetDescriptorTy, desc, stride, - rewriter.getI64ArrayAttr( - {LLVMTypeConverter::kStridePosInMemRefDescriptor, index})); + desc.setStride(rewriter, loc, index, stride); } - rewriter.replaceOp(op, desc); + rewriter.replaceOp(op, {desc}); return matchSuccess(); } }; |