diff options
4 files changed, 27 insertions, 44 deletions
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h index f40d2cfaade..f0bf3a49fa2 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -172,6 +172,9 @@ public: /// Builds IR inserting the pos-th stride into the descriptor void setStride(OpBuilder &builder, Location loc, unsigned pos, Value *stride); + /// Returns the (LLVM) type this descriptor points to. + LLVM::LLVMType getElementType(); + /*implicit*/ operator Value *() { return value; } private: diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 570b6c4bcf2..e0edb0bc047 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -344,6 +344,11 @@ void MemRefDescriptor::setPtr(OpBuilder &builder, Location loc, unsigned pos, builder.getI64ArrayAttr(pos)); } +LLVM::LLVMType MemRefDescriptor::getElementType() { + return value->getType().cast<LLVM::LLVMType>().getStructElementType( + LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); +} + namespace { // Base class for Standard to LLVM IR op conversions. Matches the Op type // provided as template argument. Carries a reference to the LLVM dialect in diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp index 21bcdc9a6db..5bda8b3fd5b 100644 --- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp @@ -177,22 +177,17 @@ public: !targetMemRefType.hasStaticShape()) return matchFailure(); - Value *sourceMemRef = operands[0]; auto llvmSourceDescriptorTy = - sourceMemRef->getType().dyn_cast<LLVM::LLVMType>(); + operands[0]->getType().dyn_cast<LLVM::LLVMType>(); if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) return matchFailure(); + MemRefDescriptor sourceMemRef(operands[0]); auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType) .dyn_cast_or_null<LLVM::LLVMType>(); if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) return matchFailure(); - Type llvmSourceElementTy = llvmSourceDescriptorTy.getStructElementType( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); - Type llvmTargetElementTy = llvmTargetDescriptorTy.getStructElementType( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); - int64_t offset; SmallVector<int64_t, 4> strides; auto successStrides = @@ -214,55 +209,36 @@ public: auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); // Create descriptor. - Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmTargetDescriptorTy); + auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); + Type llvmTargetElementTy = desc.getElementType(); // Set allocated ptr. - Value *allocated = rewriter.create<LLVM::ExtractValueOp>( - loc, llvmSourceElementTy, sourceMemRef, - rewriter.getIndexArrayAttr( - LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); + Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc); allocated = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); - desc = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), llvmTargetDescriptorTy, desc, allocated, - rewriter.getIndexArrayAttr( - LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); - // Set ptr. - Value *ptr = rewriter.create<LLVM::ExtractValueOp>( - loc, llvmSourceElementTy, sourceMemRef, - rewriter.getIndexArrayAttr( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); + desc.setAllocatedPtr(rewriter, loc, allocated); + // Set aligned ptr. + Value *ptr = sourceMemRef.alignedPtr(rewriter, loc); ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); - desc = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), llvmTargetDescriptorTy, desc, ptr, - rewriter.getIndexArrayAttr( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); + desc.setAlignedPtr(rewriter, loc, ptr); // Fill offset 0. auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); - desc = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), llvmTargetDescriptorTy, desc, zero, - rewriter.getIndexArrayAttr( - LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); + desc.setOffset(rewriter, loc, zero); + // Fill size and stride descriptors in memref. for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { int64_t index = indexedSize.index(); auto sizeAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); - desc = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), llvmTargetDescriptorTy, desc, size, - rewriter.getI64ArrayAttr( - {LLVMTypeConverter::kSizePosInMemRefDescriptor, index})); + desc.setSize(rewriter, loc, index, size); auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); - desc = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), llvmTargetDescriptorTy, desc, stride, - rewriter.getI64ArrayAttr( - {LLVMTypeConverter::kStridePosInMemRefDescriptor, index})); + desc.setStride(rewriter, loc, index, stride); } - rewriter.replaceOp(op, desc); + rewriter.replaceOp(op, {desc}); return matchSuccess(); } }; diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index ff07f52cf23..6c5e8079ee1 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -54,12 +54,11 @@ func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> { } // CHECK-LABEL: vector_type_cast // CHECK: llvm.mlir.undef : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }"> -// CHECK: %[[allocated:.*]] = llvm.extractvalue {{.*}}[0 : index] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: %[[allocated:.*]] = llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> // CHECK: %[[allocatedBit:.*]] = llvm.bitcast %[[allocated]] : !llvm<"float*"> to !llvm<"[8 x [8 x <8 x float>]]*"> -// CHECK: llvm.insertvalue %[[allocatedBit]], {{.*}}[0 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }"> -// CHECK: %[[aligned:.*]] = llvm.extractvalue {{.*}}[1 : index] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.insertvalue %[[allocatedBit]], {{.*}}[0] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }"> +// CHECK: %[[aligned:.*]] = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> // CHECK: %[[alignedBit:.*]] = llvm.bitcast %[[aligned]] : !llvm<"float*"> to !llvm<"[8 x [8 x <8 x float>]]*"> -// CHECK: llvm.insertvalue %[[alignedBit]], {{.*}}[1 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }"> +// CHECK: llvm.insertvalue %[[alignedBit]], {{.*}}[1] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }"> // CHECK: llvm.mlir.constant(0 : index -// CHECK: llvm.insertvalue {{.*}}[2 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }"> - +// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }"> |