summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h3
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp5
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp52
-rw-r--r--mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir11
4 files changed, 27 insertions, 44 deletions
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index f40d2cfaade..f0bf3a49fa2 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -172,6 +172,9 @@ public:
/// Builds IR inserting the pos-th stride into the descriptor
void setStride(OpBuilder &builder, Location loc, unsigned pos, Value *stride);
+ /// Returns the (LLVM) type this descriptor points to.
+ LLVM::LLVMType getElementType();
+
/*implicit*/ operator Value *() { return value; }
private:
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 570b6c4bcf2..e0edb0bc047 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -344,6 +344,11 @@ void MemRefDescriptor::setPtr(OpBuilder &builder, Location loc, unsigned pos,
builder.getI64ArrayAttr(pos));
}
+LLVM::LLVMType MemRefDescriptor::getElementType() {
+ return value->getType().cast<LLVM::LLVMType>().getStructElementType(
+ LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
+}
+
namespace {
// Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in
diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
index 21bcdc9a6db..5bda8b3fd5b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
@@ -177,22 +177,17 @@ public:
!targetMemRefType.hasStaticShape())
return matchFailure();
- Value *sourceMemRef = operands[0];
auto llvmSourceDescriptorTy =
- sourceMemRef->getType().dyn_cast<LLVM::LLVMType>();
+ operands[0]->getType().dyn_cast<LLVM::LLVMType>();
if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
return matchFailure();
+ MemRefDescriptor sourceMemRef(operands[0]);
auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return matchFailure();
- Type llvmSourceElementTy = llvmSourceDescriptorTy.getStructElementType(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
- Type llvmTargetElementTy = llvmTargetDescriptorTy.getStructElementType(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
-
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides =
@@ -214,55 +209,36 @@ public:
auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
// Create descriptor.
- Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmTargetDescriptorTy);
+ auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
+ Type llvmTargetElementTy = desc.getElementType();
// Set allocated ptr.
- Value *allocated = rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmSourceElementTy, sourceMemRef,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+ Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc);
allocated =
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, allocated,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
- // Set ptr.
- Value *ptr = rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmSourceElementTy, sourceMemRef,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+ desc.setAllocatedPtr(rewriter, loc, allocated);
+ // Set aligned ptr.
+ Value *ptr = sourceMemRef.alignedPtr(rewriter, loc);
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, ptr,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+ desc.setAlignedPtr(rewriter, loc, ptr);
// Fill offset 0.
auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, zero,
- rewriter.getIndexArrayAttr(
- LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+ desc.setOffset(rewriter, loc, zero);
+
// Fill size and stride descriptors in memref.
for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
int64_t index = indexedSize.index();
auto sizeAttr =
rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, size,
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
+ desc.setSize(rewriter, loc, index, size);
auto strideAttr =
rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
- desc = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), llvmTargetDescriptorTy, desc, stride,
- rewriter.getI64ArrayAttr(
- {LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
+ desc.setStride(rewriter, loc, index, stride);
}
- rewriter.replaceOp(op, desc);
+ rewriter.replaceOp(op, {desc});
return matchSuccess();
}
};
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ff07f52cf23..6c5e8079ee1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -54,12 +54,11 @@ func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
}
// CHECK-LABEL: vector_type_cast
// CHECK: llvm.mlir.undef : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
-// CHECK: %[[allocated:.*]] = llvm.extractvalue {{.*}}[0 : index] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: %[[allocated:.*]] = llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: %[[allocatedBit:.*]] = llvm.bitcast %[[allocated]] : !llvm<"float*"> to !llvm<"[8 x [8 x <8 x float>]]*">
-// CHECK: llvm.insertvalue %[[allocatedBit]], {{.*}}[0 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
-// CHECK: %[[aligned:.*]] = llvm.extractvalue {{.*}}[1 : index] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue %[[allocatedBit]], {{.*}}[0] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
+// CHECK: %[[aligned:.*]] = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: %[[alignedBit:.*]] = llvm.bitcast %[[aligned]] : !llvm<"float*"> to !llvm<"[8 x [8 x <8 x float>]]*">
-// CHECK: llvm.insertvalue %[[alignedBit]], {{.*}}[1 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
+// CHECK: llvm.insertvalue %[[alignedBit]], {{.*}}[1] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
// CHECK: llvm.mlir.constant(0 : index
-// CHECK: llvm.insertvalue {{.*}}[2 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
-
+// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
OpenPOWER on IntegriCloud